Oxford Nanopore Sequencing Benchmark¶

This report presents a benchmark of SNVs, indels, and SVs, and a characterisation of the ONT dataset used for this benchmark.

Methods¶

Data Processing¶

  • Basecalling: wf-basecalling v1.1.7

  • Alignment and Variant Calling: wf-human-variation v2.1.0

Quality Control Tools¶

  • NanoPlot: 1.42.0

    • Generates summary statistics for each sample and creates visualizations of QC metrics for sequencing summaries and aligned BAM files.
  • NanoComp: 1.23.1

    • Compares multiple sequencing runs and generates comparative plots.
  • mosdepth: 0.3.3

    • Calculates sequencing depth across the human genome for each sample.
  • rtg-tools: 3.12.1

    • Performs performs variant comparison against a truth dataset.
  • SURVIVOR: 1.0.7

    • Performs merging of vcf files to compare SVs within a sample and among populations/samples.
In [1]:
# Standard library imports
import os
import glob
import gzip
import pickle
import logging
import re
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from statistics import mean
from typing import Any, DefaultDict, Dict, List, Literal, Optional, Set, Tuple, Union

# Third-party imports
import numpy as np
import polars as pl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
import statsmodels.api as sm
import pysam
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from scipy import stats
from scipy.optimize import curve_fit
from scipy.spatial.distance import cdist
from sklearn import metrics
from statsmodels.stats.multitest import multipletests

# Seaborn settings
sns.set_style("whitegrid")
sns.set_context("paper")
sns.set_palette("colorblind")

# Logging settings
logging.basicConfig(
    level=logging.INFO,
    format="%(name)s - %(levelname)s - %(message)s",
    force=True,
)
logger = logging.getLogger(__name__)

Sequencing Quality Control¶

Aggregate table of the QC metrics from NanoStats for both singleplexed and multiplexed samples, from the aligned .cram files produced by wf-human-variation.

Unless otherwise specified, subsequent plots and statistics include only samples basecalled with the sup algorithm.

In [2]:
def get_default_column_types() -> Dict[str, str]:
    """
    Define the data types for the columns in the DataFrame.

    Returns:
        Dict[str, str]: Mapping of column names to their types
    """
    return {
        "multiplexing": "category",
        "basecall": "category",
        "anonymised_sample": "category",
        "number_of_reads": "numeric",
        "number_of_bases": "numeric",
        "number_of_bases_aligned": "numeric",
        "fraction_bases_aligned": "numeric",
        "mean_read_length": "numeric",
        "median_read_length": "numeric",
        "read_length_stdev": "numeric",
        "n50": "numeric",
        "mean_qual": "numeric",
        "median_qual": "numeric",
        "average_identity": "numeric",
        "Reads >Q5_percentage": "numeric",
        "Reads >Q7_percentage": "numeric",
        "Reads >Q10_percentage": "numeric",
        "Reads >Q12_percentage": "numeric",
        "Reads >Q15_percentage": "numeric",
    }


@dataclass
class NanoStatsConfig:
    """Configuration for NanoStats parsing with default settings."""

    skip_categories: tuple = ("longest_read_(with_Q)", "highest_Q_read_(with_length)")
    column_types: Dict[str, str] = field(default_factory=get_default_column_types)
    required_metrics: tuple = (
        "number_of_reads",
        "number_of_bases",
        "median_read_length",
        "mean_read_length",
        "read_length_stdev",
        "n50",
        "mean_qual",
        "median_qual",
        "Reads_>Q5",
        "Reads_>Q7",
        "Reads_>Q10",
        "Reads_>Q12",
        "Reads_>Q15",
    )


def _parse_nanostats_file(file_path: Path) -> Dict[str, float]:
    """
    Parse a NanoStats.txt file and extract metrics.

    Args:
        file_path (Path): Path to NanoStats.txt file

    Returns:
        Dict[str, float]: Dictionary of metrics and their values

    Raises:
        ValueError: If required metrics are missing from the file
    """
    metrics = {}
    try:
        with open(file_path) as f:
            next(f)  # Skip header line
            for line in f:
                key, value = line.strip().split("\t")
                key = key.strip(":")

                if any(skip in key for skip in NanoStatsConfig.skip_categories):
                    continue

                if key.startswith("Reads >Q"):
                    match = re.search(r"\((\d+\.\d+)%\)", value)
                    if match:
                        clean_key = key.replace(" ", "_")
                        metrics[clean_key] = float(match.group(1)) / 100
                    continue

                try:
                    clean_value = value.split()[0].replace(",", "")
                    metrics[key.lower().replace(" ", "_")] = float(clean_value)
                except (ValueError, IndexError):
                    logger.warning(
                        f"Could not parse value for metric {key} in {file_path}"
                    )

        # Verify required metrics
        missing_metrics = [
            metric
            for metric in NanoStatsConfig.required_metrics
            if metric not in metrics
        ]
        if missing_metrics:
            raise ValueError(f"Missing required metrics: {missing_metrics}")

        return metrics

    except FileNotFoundError:
        logger.error(f"NanoStats file not found: {file_path}")
        raise
    except Exception as e:
        logger.error(f"Error parsing NanoStats file {file_path}: {str(e)}")
        raise


def _get_multiplexing_status(seq_summaries_dir: Path, sample_id: str) -> str:
    """
    Determine if a sample is multiplexed.

    Args:
        seq_summaries_dir (Path): Directory containing sequencing summaries
        sample_id (str): Sample identifier

    Returns:
        str: 'multiplex' or 'singleplex'
    """
    try:
        for dir_path in seq_summaries_dir.glob("*"):
            if not dir_path.is_dir():
                continue
            samples = dir_path.name.split("__")
            if sample_id in samples:
                return "multiplex" if len(samples) > 1 else "singleplex"
        return "singleplex"
    except Exception as e:
        logger.error(f"Error determining multiplexing status for {sample_id}: {str(e)}")
        raise


def _extract_sample_info(dir_path: Path) -> Tuple[str, str]:
    """
    Extract sample ID and basecall type from directory path.

    Args:
        dir_path (Path): Directory path containing sample information

    Returns:
        Tuple[str, str]: Tuple of (sample_id, basecall_type)

    Raises:
        ValueError: If directory name format is invalid
    """
    try:
        parts = dir_path.name.split("_")
        if len(parts) < 2:
            raise ValueError(f"Invalid directory name format: {dir_path.name}")
        sample_id = "_".join(parts[:-1])
        basecall = parts[-1]
        return sample_id, basecall
    except Exception as e:
        logger.error(f"Error extracting sample info from {dir_path}: {str(e)}")
        raise


def parse_nanostats(
    aligned_bams_dir: Path,
    seq_summaries_dir: Path,
) -> pl.DataFrame:
    """
    Parse NanoStats files and create a DataFrame with QC metrics.

    Args:
        aligned_bams_dir (Path): Directory containing aligned BAM files
        seq_summaries_dir (Path): Directory containing sequencing summaries

    Returns:
        pl.DataFrame: Polars DataFrame containing parsed metrics, sorted by sample ID

    Raises:
        FileNotFoundError: If input directories don't exist
        ValueError: If no valid samples are found
    """
    if not aligned_bams_dir.exists():
        raise FileNotFoundError(f"Aligned BAMs directory not found: {aligned_bams_dir}")
    if not seq_summaries_dir.exists():
        raise FileNotFoundError(
            f"Sequencing summaries directory not found: {seq_summaries_dir}"
        )

    data: List[Dict] = []
    sample_ids = set()

    # Collect all sample IDs
    for dir_path in aligned_bams_dir.glob("*"):
        if not dir_path.is_dir():
            continue
        sample_id, _ = _extract_sample_info(dir_path)
        sample_ids.add(sample_id)

    if not sample_ids:
        raise ValueError("No valid samples found in the input directory")

    # Create anonymised sample mapping
    sample_mapping = {
        sample_id: f"Sample {i+1}" for i, sample_id in enumerate(sorted(sample_ids))
    }

    # Process each sample directory
    for dir_path in aligned_bams_dir.glob("*"):
        if not dir_path.is_dir():
            continue

        try:
            nanostats_file = dir_path / "NanoStats.txt"
            sample_id, basecall = _extract_sample_info(dir_path)
            metrics = _parse_nanostats_file(nanostats_file)

            sample_data = {
                "sample": sample_id,
                "anonymised_sample": sample_mapping[sample_id],
                "basecall": basecall,
                "multiplexing": _get_multiplexing_status(seq_summaries_dir, sample_id),
                **metrics,
            }
            data.append(sample_data)
        except Exception as e:
            logger.error(f"Error processing directory {dir_path}: {str(e)}")
            continue

    if not data:
        raise ValueError("No valid data could be parsed from any sample")

    df = pl.DataFrame(data)
    return df.sort(["sample", "anonymised_sample"])


np_seq_summaries_dir = Path(
    "/scratch/prj/ppn_als_longread/ont-benchmark/qc/nanoplot/seq_summaries/"
)
np_aligned_bams_dir = Path(
    "/scratch/prj/ppn_als_longread/ont-benchmark/qc/nanoplot/aligned_bams/"
)

nanoplot_qc_metrics_df = parse_nanostats(
    aligned_bams_dir=np_aligned_bams_dir,
    seq_summaries_dir=np_seq_summaries_dir,
)
logger.info(f"Successfully processed {len(nanoplot_qc_metrics_df)} samples")

with pl.Config(tbl_rows=len(nanoplot_qc_metrics_df)):
    display(nanoplot_qc_metrics_df)
__main__ - INFO - Successfully processed 14 samples
shape: (14, 21)
sampleanonymised_samplebasecallmultiplexingnumber_of_readsnumber_of_basesnumber_of_bases_alignedfraction_bases_alignedmedian_read_lengthmean_read_lengthread_length_stdevn50average_identitymedian_identitymean_qualmedian_qualReads_>Q5Reads_>Q7Reads_>Q10Reads_>Q12Reads_>Q15
strstrstrstrf64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64
"A046_12""Sample 1""sup""multiplex"5.063989e63.3558e103.1483e100.94600.06626.86582.911608.097.598.917.920.90.9990.9980.9930.9770.892
"A048_09""Sample 2""sup""multiplex"7.088011e64.0078e103.7407e100.92996.05654.36755.611981.097.598.817.620.30.9990.9990.9940.9780.883
"A079_07""Sample 3""sup""multiplex"3.813948e63.1232e102.9834e101.05219.08189.08440.015713.097.498.817.820.50.9990.9980.9930.9780.892
"A081_91""Sample 4""sup""multiplex"3.278883e62.5565e102.4259e100.94066.07797.09011.516853.097.498.817.720.40.9990.9980.9930.9790.89
"A085_00""Sample 5""sup""multiplex"3.767749e62.7359e102.5779e100.94120.07261.48121.215061.097.298.717.119.80.9990.9980.9910.9690.857
"A097_92""Sample 6""sup""multiplex"4.264823e63.5497e103.3586e100.96204.08323.17593.414429.097.198.717.320.10.9990.9980.9910.9730.87
"A149_01""Sample 7""sup""singleplex"8.228301e64.7190e104.4520e100.93532.05735.16354.110895.097.498.817.620.30.9990.9990.9940.9760.877
"A153_01""Sample 8""sup""singleplex"7.346662e65.0667e104.8132e100.96660.06896.75705.610074.097.498.918.321.41.00.9990.9950.9830.905
"A153_06""Sample 9""sup""singleplex"1.1559255e77.8039e107.3883e100.95641.06751.26382.810973.097.498.917.820.70.9990.9980.9930.9780.892
"A154_04""Sample 10""sup""singleplex"9.270031e65.4649e105.1180e100.93996.05895.26062.210939.097.398.817.620.60.9990.9980.9920.9750.879
"A154_06""Sample 11""sup""singleplex"8.338801e65.7601e105.4980e101.06073.06907.66178.511330.097.398.717.620.21.00.9990.9950.980.882
"A157_02""Sample 12""sup""singleplex"8.250276e65.3573e105.1193e101.06075.06493.55350.59628.097.398.818.020.91.00.9990.9950.9810.893
"A160_96""Sample 13""sup""singleplex"1.1591344e78.3603e108.0174e101.06601.07212.56210.711126.097.498.817.820.61.00.9990.9950.9810.884
"A162_09""Sample 14""sup""singleplex"1.487444e78.9966e108.5759e101.03957.06048.36378.411325.097.598.817.720.31.00.9990.9950.9790.878

Sequencing Yield¶

1. Raw Yields¶

In [3]:
def _create_yield_plot(
    ax: plt.Axes,
    data: pl.DataFrame,
    x: str,
    y: str,
    hue: str,
    title: str,
    xlabel: str,
    ylabel: str,
) -> None:
    """
    Create a bar plot showing yield metrics.

    Args:
        ax (plt.Axes): Matplotlib axes object to plot on
        data (pl.DataFrame): Polars DataFrame containing the data
        x (str): Column name for x-axis
        y (str): Column name for y-axis
        hue (str): Column name for color grouping
        title (str): Plot title
        xlabel (str): X-axis label
        ylabel (str): Y-axis label

    Raises:
        ValueError: If required columns are not found in the DataFrame
    """
    try:
        # Validate input columns
        required_cols = {x, y, hue}
        if not required_cols.issubset(data.columns):
            missing = required_cols - set(data.columns)
            raise ValueError(f"Missing required columns: {missing}")

        # Create plot with sorted data
        sns.barplot(x=x, y=y, hue=hue, data=data, ax=ax)

        ax.set_title(title)
        ax.set_xlabel(xlabel)

        # Get the scale factor from the formatter
        formatter = ax.yaxis.get_major_formatter()
        if hasattr(formatter, "orderOfMagnitude"):
            scale = formatter.orderOfMagnitude
            ylabel = f"{ylabel} ($1×10^{{{scale}}}$)"

        ax.set_ylabel(ylabel)

        # Rotate x-axis labels
        for tick in ax.get_xticklabels():
            tick.set_rotation(45)
            tick.set_ha("right")

        # Adjust x-tick positions
        locs, _ = ax.get_xticks(), ax.get_xticklabels()
        ax.set_xticks([loc + 0.2 for loc in locs])

        # Create legend
        legend = ax.legend(title=hue.title())
        legend.get_title().set_weight("bold")

    except Exception as e:
        logger.error(f"Error creating yield plot: {str(e)}")
        raise


def plot_sample_yields(
    metrics_df: pl.DataFrame,
    basecall_type: str = "sup",
    figsize: Tuple[int, int] = (16, 6),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Create plots showing read and base yields for samples.

    Args:
        metrics_df (pl.DataFrame): Polars DataFrame containing metrics data
        basecall_type (str, optional): Basecall type to filter for. Defaults to "sup".
        figsize (Tuple[int, int], optional): Figure size. Defaults to (16, 6).
        dpi (int, optional): Figure DPI. Defaults to 300.
        gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.

    Returns:
        Optional[plt.Figure]: Figure object if created independently (no GridSpec provided)

    Raises:
        ValueError: If DataFrame doesn't contain required columns.
    """
    try:
        # Validate input data
        required_cols = {
            "basecall",
            "anonymised_sample",
            "number_of_reads",
            "number_of_bases",
            "multiplexing",
        }
        if not required_cols.issubset(metrics_df.columns):
            missing = required_cols - set(metrics_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        # Filter data
        yields_df = metrics_df.filter(pl.col("basecall") == basecall_type)

        if len(yields_df) == 0:
            raise ValueError(f"No data found for basecall_type: {basecall_type}")

        # Create figure and axes based on whether GridSpec is provided
        if gs is None:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax1 = fig.add_subplot(gs[0, 0])
            ax2 = fig.add_subplot(gs[0, 1])

        # Create plots
        _create_yield_plot(
            ax1,
            yields_df,
            "anonymised_sample",
            "number_of_reads",
            "multiplexing",
            f"Read Yield per Sample",
            "Sample",
            "Number of Reads",
        )

        _create_yield_plot(
            ax2,
            yields_df,
            "anonymised_sample",
            "number_of_bases",
            "multiplexing",
            f"Base Yield per Sample",
            "Sample",
            "Number of Bases",
        )

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error plotting sample yields: {str(e)}")
        raise


yields_plots = plot_sample_yields(nanoplot_qc_metrics_df)
No description has been provided for this image

Summary Stats¶

In [4]:
@dataclass
class YieldMetrics:
    """
    Data class for storing yield metrics statistics.
    """

    max: float
    min: float
    mean: float
    std: float
    median: float


@dataclass
class YieldStats:
    """
    Data class for storing read and base yield statistics.
    """

    reads: YieldMetrics
    bases: YieldMetrics


def _format_number_separator(num: float) -> str:
    """
    Format a number with thousand separators.

    Args:
        num (float): Number to format with thousand separators

    Returns:
        str: Formatted string representation of the number with thousand separators

    Examples:
        >>> _format_number_separator(1234567.89)
        '1,234,568'
    """
    return f"{num:,.0f}"


def _calculate_yield_stats(df: pl.DataFrame) -> YieldStats:
    """
    Calculate yield statistics from a Polars DataFrame.

    Args:
        df (pl.DataFrame): Input DataFrame containing yield metrics with columns
            'number_of_reads' and 'number_of_bases'

    Returns:
        YieldStats: Object containing read and base statistics

    Raises:
        Exception: If there's an error calculating statistics from the DataFrame
        KeyError: If required columns are missing from the DataFrame
    """
    try:
        reads_metrics = YieldMetrics(
            max=df.select(pl.col("number_of_reads").max()).item(),
            min=df.select(pl.col("number_of_reads").min()).item(),
            mean=df.select(pl.col("number_of_reads").mean()).item(),
            std=df.select(pl.col("number_of_reads").std()).item(),
            median=df.select(pl.col("number_of_reads").median()).item(),
        )

        bases_metrics = YieldMetrics(
            max=df.select(pl.col("number_of_bases").max()).item(),
            min=df.select(pl.col("number_of_bases").min()).item(),
            mean=df.select(pl.col("number_of_bases").mean()).item(),
            std=df.select(pl.col("number_of_bases").std()).item(),
            median=df.select(pl.col("number_of_bases").median()).item(),
        )

        return YieldStats(reads=reads_metrics, bases=bases_metrics)
    except Exception as e:
        logger.error(f"Error calculating yield statistics: {str(e)}")
        raise


def _print_yield_stats(stats: YieldStats, sample_type: str) -> None:
    """
    Print formatted yield statistics.

    Args:
        stats (YieldStats): YieldStats object containing statistics to print
        sample_type (str): Type of sample (Multiplexed/Singleplexed)

    Raises:
        Exception: If there's an error formatting or printing the statistics
    """
    try:
        logger.info(f"Printing statistics for {sample_type} samples")
        print(f"\n{sample_type} Samples Statistics:")
        print("=" * 40)

        for metric_name, metrics in [("Reads", stats.reads), ("Bases", stats.bases)]:
            print(f"\n{metric_name}:")
            for stat_name, value in vars(metrics).items():
                formatted_value = _format_number_separator(value)
                print(f"  {stat_name.capitalize():6s}: {formatted_value}")
    except Exception as e:
        logger.error(f"Error printing yield statistics: {str(e)}")
        raise


def _calculate_percentage_increase(
    singleplex_val: float, multiplex_val: float
) -> float:
    """
    Calculate percentage increase between two values.

    Args:
        singleplex_val (float): Value from singleplex samples
        multiplex_val (float): Value from multiplex samples

    Returns:
        float: Percentage increase between the two values

    Raises:
        ZeroDivisionError: If multiplex value is zero
        Exception: For other calculation errors

    Examples:
        >>> _calculate_percentage_increase(200, 100)
        100.0
    """
    try:
        return ((singleplex_val - multiplex_val) / multiplex_val) * 100
    except ZeroDivisionError:
        logger.error("Cannot calculate percentage increase: multiplex value is zero")
        raise
    except Exception as e:
        logger.error(f"Error calculating percentage increase: {str(e)}")
        raise


def analyze_yields(df: pl.DataFrame) -> None:
    """
    Analyze and print yield statistics for multiplexed and singleplexed samples.

    Args:
        df (pl.DataFrame): Input DataFrame containing yield metrics with columns:
            - multiplexing: str ('singleplex' or 'multiplex')
            - basecall: str ('sup' or other)
            - number_of_reads: int/float
            - number_of_bases: int/float

    Raises:
        Exception: If there's an error during analysis
        ValueError: If required data is missing from the DataFrame
    """
    try:
        singleplex_yields = df.filter(
            (pl.col("multiplexing") == "singleplex") & (pl.col("basecall") == "sup")
        )
        multiplex_yields = df.filter(
            (pl.col("multiplexing") == "multiplex") & (pl.col("basecall") == "sup")
        )

        if singleplex_yields.height == 0 or multiplex_yields.height == 0:
            logger.warning("No data found for either singleplex or multiplex samples")
            return

        singleplex_stats = _calculate_yield_stats(singleplex_yields)
        multiplex_stats = _calculate_yield_stats(multiplex_yields)

        _print_yield_stats(singleplex_stats, "Singleplexed")
        _print_yield_stats(multiplex_stats, "Multiplexed")

        print("\nPercentage Increase (Singleplexed vs Multiplexed):")
        print("=" * 40)

        increase_reads = _calculate_percentage_increase(
            singleplex_stats.reads.mean, multiplex_stats.reads.mean
        )
        increase_bases = _calculate_percentage_increase(
            singleplex_stats.bases.mean, multiplex_stats.bases.mean
        )

        print(f"Mean Number of Reads: {increase_reads:6.2f}%")
        print(f"Mean Number of Bases: {increase_bases:6.2f}%")

        logger.info("Yield analysis completed successfully")

    except Exception as e:
        logger.error(f"Error in yield analysis: {str(e)}")
        raise


analyze_yields(nanoplot_qc_metrics_df)
__main__ - INFO - Printing statistics for Singleplexed samples
__main__ - INFO - Printing statistics for Multiplexed samples
__main__ - INFO - Yield analysis completed successfully
Singleplexed Samples Statistics:
========================================

Reads:
  Max   : 14,874,440
  Min   : 7,346,662
  Mean  : 9,932,389
  Std   : 2,541,662
  Median: 8,804,416

Bases:
  Max   : 89,965,577,035
  Min   : 47,190,301,151
  Mean  : 64,410,968,490
  Std   : 16,697,589,113
  Median: 56,124,706,342

Multiplexed Samples Statistics:
========================================

Reads:
  Max   : 7,088,011
  Min   : 3,278,883
  Mean  : 4,546,234
  Std   : 1,382,487
  Median: 4,039,386

Bases:
  Max   : 40,077,847,821
  Min   : 25,565,442,444
  Mean  : 32,214,905,852
  Std   : 5,350,881,127
  Median: 32,395,219,886

Percentage Increase (Singleplexed vs Multiplexed):
========================================
Mean Number of Reads: 118.48%
Mean Number of Bases:  99.94%

2. Read Lengths¶

In [5]:
def _create_length_subplot(
    data: pl.DataFrame, ax: plt.Axes, title: str, hue: str
) -> None:
    """
    Create a subplot showing read length metrics.

    Args:
        data (pl.DataFrame): DataFrame containing the plot data
        ax (plt.Axes): Matplotlib axes object to plot on
        title (str): Plot title
        hue (str): Column name for color grouping

    Raises:
        ValueError: If required columns are missing from DataFrame
    """
    try:
        sns.barplot(
            x="anonymised_sample",
            y="read_length",
            hue=hue,
            data=data,
            errorbar=None,
            ax=ax,
        )

        ax.set_title(title)
        ax.set_xlabel("Sample")

        formatter = ax.yaxis.get_major_formatter()
        ylabel = "Read Length (bp)"
        ax.set_ylabel(ylabel)

        for tick in ax.get_xticklabels():
            tick.set_rotation(45)
            tick.set_ha("right")

        locs, _ = ax.get_xticks(), ax.get_xticklabels()
        ax.set_xticks([loc + 0.2 for loc in locs])

        legend = ax.legend(title=hue.title())
        legend.get_title().set_weight("bold")

    except Exception as e:
        logger.error(f"Error creating length subplot: {str(e)}")
        raise


def plot_read_lengths(
    metrics_df: pl.DataFrame,
    figsize: Tuple[int, int] = (16, 6),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Create plots showing read length distributions for samples.

    Args:
        metrics_df (pl.DataFrame): Input DataFrame containing metrics data
        figsize (Tuple[int, int], optional): Figure size. Defaults to (16, 6).
        dpi (int, optional): Figure DPI. Defaults to 300.
        gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.

    Returns:
        Optional[plt.Figure]: Figure object if created independently.

    Raises:
        ValueError: If DataFrame doesn't contain required columns.
    """
    try:
        required_cols = {
            "basecall",
            "sample",
            "anonymised_sample",
            "multiplexing",
            "mean_read_length",
            "median_read_length",
        }
        if not required_cols.issubset(metrics_df.columns):
            missing = required_cols - set(metrics_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        # Filter and prepare data
        plot_data = (
            metrics_df.filter(pl.col("basecall") == "sup")
            .select(
                [
                    "sample",
                    "anonymised_sample",
                    "multiplexing",
                    "mean_read_length",
                    "median_read_length",
                ]
            )
            .unpivot(
                index=["sample", "anonymised_sample", "multiplexing"],
                on=["mean_read_length", "median_read_length"],
                variable_name="read_length_type",
                value_name="read_length",
            )
        )

        if gs is None:
            fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=True, dpi=dpi)
        else:
            fig = plt.gcf()
            axes = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])]

        for ax, read_length_type in zip(
            axes, ["mean_read_length", "median_read_length"]
        ):
            title = read_length_type.replace("_", " ").title()
            data = plot_data.filter(pl.col("read_length_type") == read_length_type)
            _create_length_subplot(data, ax, title, hue="multiplexing")

            if ax != axes[0]:
                ax.yaxis.set_tick_params(labelleft=True)

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error plotting read lengths: {str(e)}")
        raise


read_lengths_plot = plot_read_lengths(nanoplot_qc_metrics_df)
No description has been provided for this image
In [6]:
def load_nanoplot_data(base_dir: Path, metrics_df: pl.DataFrame) -> pl.DataFrame:
    """
    Load NanoPlot data from pickle files and combine with metrics.

    Args:
        base_dir (Path): Base directory containing NanoPlot data files
        metrics_df (pl.DataFrame): DataFrame containing sample metrics

    Returns:
        pl.DataFrame: Combined NanoPlot data for all samples

    Raises:
        FileNotFoundError: If pickle file is not found
        ValueError: If required columns are missing
    """
    required_columns = ("readIDs", "quals", "lengths", "mapQ")
    data_list = []

    try:
        for row in metrics_df.iter_rows(named=True):
            sample_dir = f"{row['sample']}_{row['basecall']}"
            pickle_path = base_dir / sample_dir / "NanoPlot-data.pickle"

            if not pickle_path.is_file():
                logger.warning(f"Pickle file not found: {pickle_path}")
                continue

            with open(pickle_path, "rb") as file:
                nanoplot_data = pickle.load(file)

            sample_df = pl.DataFrame(nanoplot_data).select(required_columns)
            sample_df = sample_df.with_columns(
                [
                    pl.lit(row["anonymised_sample"]).alias("anonymised_sample"),
                    pl.lit(row["basecall"]).alias("basecall"),
                ]
            )
            data_list.append(sample_df)

        if not data_list:
            raise ValueError("No valid data found in any pickle files")

        return pl.concat(data_list)

    except Exception as e:
        logger.error(f"Error loading NanoPlot data: {str(e)}")
        raise


def process_aligned_nanoplot_data(
    nanoplot_data: pl.DataFrame, metrics_df: pl.DataFrame
) -> pl.DataFrame:
    """
    Process NanoPlot data by merging with metrics and binning read lengths.

    Args:
        nanoplot_data (pl.DataFrame): Raw NanoPlot data
        metrics_df (pl.DataFrame): Metrics DataFrame

    Returns:
        pl.DataFrame: Processed DataFrame with binned lengths

    Raises:
        ValueError: If required columns are missing
    """
    try:
        metrics_subset = metrics_df.select(
            ["anonymised_sample", "multiplexing", "basecall", "number_of_reads"]
        )

        processed_data = nanoplot_data.join(
            metrics_subset, on=["anonymised_sample", "basecall"]
        )

        max_length = processed_data.select(pl.col("lengths").max()).item()
        bins = np.logspace(np.log10(10), np.log10(max_length), num=100)

        processed_data = processed_data.with_columns(
            [pl.col("lengths").cut(bins).alias("length_bin")]
        )

        return processed_data

    except Exception as e:
        logger.error(f"Error processing NanoPlot data: {str(e)}")
        raise


def calculate_read_length_distribution(
    processed_data: pl.DataFrame, basecall_type: str = "sup"
) -> pl.DataFrame:
    """
    Calculate read length distribution statistics.

    Args:
        processed_data (pl.DataFrame): Processed NanoPlot data
        basecall_type (str, optional): Basecall type to filter. Defaults to "sup"

    Returns:
        pl.DataFrame: Length distribution statistics

    Raises:
        ValueError: If required columns are missing
    """
    try:
        filtered_data = processed_data.filter(pl.col("basecall") == basecall_type)

        # Extract bin edges from the categorical length_bin column
        bin_categories = filtered_data.select(pl.col("length_bin").unique()).to_series()
        bin_edges = [float(edge.split(",")[0][1:]) for edge in bin_categories]
        bin_centers = [
            (bin_edges[i] + bin_edges[i + 1]) / 2 for i in range(len(bin_edges) - 1)
        ]

        # Create a mapping DataFrame to calculate bin centers
        mapping_df = pl.DataFrame(
            {
                "length_bin": bin_categories[:-1],  # Exclude the last bin edge
                "bin_center": bin_centers,
            }
        )

        length_dist = (
            filtered_data.group_by(["anonymised_sample", "length_bin", "multiplexing"])
            .agg(pl.len().alias("count"))
            .join(
                filtered_data.select(["anonymised_sample", "number_of_reads"]).unique(),
                on="anonymised_sample",
            )
        )

        # Add percentage and join with mapping DataFrame to get bin centers
        length_dist = length_dist.with_columns(
            [(pl.col("count") / pl.col("number_of_reads") * 100).alias("percentage")]
        ).join(mapping_df, on="length_bin", how="left")

        return length_dist

    except Exception as e:
        logger.error(f"Error calculating length distribution: {str(e)}")
        raise


def plot_read_length_distribution(
    length_dist: pl.DataFrame,
    max_length: float,
    figsize: Tuple[int, int] = (14, 6),
    dpi: int = 300,
    x_scale: str = "log",
    x_min: int = 10,
    num_ticks: int = 20,
    line_alpha: float = 0.8,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot the distribution of read lengths across samples.

    Args:
        length_dist (pl.DataFrame): DataFrame containing length distribution data
        max_length (float): Maximum read length for x-axis limit
        figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6).
        dpi (int, optional): Figure DPI. Defaults to 300.
        x_scale (str, optional): Scale for x-axis. Defaults to "log".
        x_min (int, optional): Minimum x-axis value. Defaults to 10.
        num_ticks (int, optional): Number of x-axis ticks. Defaults to 20.
        line_alpha (float, optional): Line transparency. Defaults to 0.8.
        gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.

    Returns:
        Optional[Figure]: Figure object if created independently.

    Raises:
        ValueError: If required columns are missing.
    """
    try:
        required_cols = {
            "anonymised_sample",
            "bin_center",
            "percentage",
            "multiplexing",
        }
        if not all(col in length_dist.columns for col in required_cols):
            missing = required_cols - set(length_dist.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        # Filter and prepare data
        non_zero_samples = (
            length_dist.group_by("anonymised_sample")
            .agg(pl.col("percentage").sum())
            .filter(pl.col("percentage") > 0)
            .select("anonymised_sample")
        )

        filtered_dist = (
            length_dist.join(non_zero_samples, on="anonymised_sample")
            .with_columns(
                [
                    pl.col("anonymised_sample")
                    .str.extract(r"(\d+)")
                    .cast(pl.Int32)
                    .alias("sample_num")
                ]
            )
            .sort("sample_num")
            .rename(
                {
                    "multiplexing": r"$\mathbf{Multiplexing}$",
                    "anonymised_sample": r"$\mathbf{Sample}$",
                }
            )
        )

        sns.lineplot(
            data=filtered_dist,
            x="bin_center",
            y="percentage",
            hue=r"$\mathbf{Sample}$",
            style=r"$\mathbf{Multiplexing}$",
            alpha=line_alpha,
            ax=ax,
        )

        ax.legend(loc="upper left", bbox_to_anchor=(1, 1.1))

        ax.set_xscale(x_scale)
        ax.set_xlabel("Read Length (bp)")
        ax.set_ylabel("Proportion of Reads (%)")
        ax.set_title("Distribution of Read Lengths")

        # Set x-axis ticks
        tick_positions = np.logspace(
            np.log10(x_min), np.log10(max_length), num=num_ticks
        )
        ax.set_xticks(tick_positions)
        ax.set_xticklabels([f"{int(tick):,}" for tick in tick_positions])

        # Set axis limits
        ax.set_xlim(left=x_min, right=max_length)

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error plotting read length distribution: {str(e)}")
        raise


nanoplot_aligned_metrics = load_nanoplot_data(
    np_aligned_bams_dir, nanoplot_qc_metrics_df
)

processed_aligned_nanoplot_df = process_aligned_nanoplot_data(
    nanoplot_aligned_metrics, nanoplot_qc_metrics_df
)

max_read_length = processed_aligned_nanoplot_df.select(pl.col("lengths").max()).item()

read_length_distribution = calculate_read_length_distribution(
    processed_aligned_nanoplot_df
)

read_length_dist_plot = plot_read_length_distribution(
    read_length_distribution, max_read_length
)
No description has been provided for this image
In [7]:
@dataclass
class LengthMetrics:
    """
    Data class for storing read length metrics statistics.
    """

    max: float
    min: float
    mean: float
    std: float
    median: float


def _calculate_length_stats(df: pl.DataFrame) -> LengthMetrics:
    """
    Calculate read length statistics from a Polars DataFrame.

    Args:
        df (pl.DataFrame): Input DataFrame containing length metrics with column 'lengths'

    Returns:
        LengthMetrics: Object containing read length statistics

    Raises:
        Exception: If there's an error calculating statistics from the DataFrame
        KeyError: If required column is missing from the DataFrame
    """
    try:
        return LengthMetrics(
            max=df.select(pl.col("lengths").max()).item(),
            min=df.select(pl.col("lengths").min()).item(),
            mean=df.select(pl.col("lengths").mean()).item(),
            std=df.select(pl.col("lengths").std()).item(),
            median=df.select(pl.col("lengths").median()).item(),
        )
    except Exception as e:
        logger.error(f"Error calculating length statistics: {str(e)}")
        raise


def _print_length_stats(stats: LengthMetrics, sample_type: str) -> None:
    """
    Print formatted read length statistics.

    Args:
        stats (LengthMetrics): LengthMetrics object containing statistics to print
        sample_type (str): Type of sample (Multiplexed/Singleplexed)

    Raises:
        Exception: If there's an error formatting or printing the statistics
    """
    try:
        logger.info(f"Printing length statistics for {sample_type} samples")
        print(f"\n{sample_type} Samples Statistics:")
        print("=" * 40)
        print("\nRead Lengths:")
        for stat_name, value in vars(stats).items():
            formatted_value = _format_number_separator(value)
            print(f"  {stat_name.capitalize():6s}: {formatted_value}")
    except Exception as e:
        logger.error(f"Error printing length statistics: {str(e)}")
        raise


def analyze_lengths(df: pl.DataFrame) -> None:
    """
    Analyze and print read length statistics for multiplexed and singleplexed samples.

    Args:
        df (pl.DataFrame): Input DataFrame containing length metrics with columns:
            - multiplexing: str ('singleplex' or 'multiplex')
            - basecall: str ('sup' or other)
            - lengths: int/float

    Raises:
        Exception: If there's an error during analysis
        ValueError: If required data is missing from the DataFrame
    """
    try:
        singleplex_lengths = df.filter(
            (pl.col("multiplexing") == "singleplex") & (pl.col("basecall") == "sup")
        )
        multiplex_lengths = df.filter(
            (pl.col("multiplexing") == "multiplex") & (pl.col("basecall") == "sup")
        )

        if singleplex_lengths.height == 0 or multiplex_lengths.height == 0:
            logger.warning("No data found for either singleplex or multiplex samples")
            return

        singleplex_stats = _calculate_length_stats(singleplex_lengths)
        multiplex_stats = _calculate_length_stats(multiplex_lengths)

        _print_length_stats(singleplex_stats, "Singleplexed")
        _print_length_stats(multiplex_stats, "Multiplexed")

        print("\nPercentage Increase (Singleplexed vs Multiplexed):")
        print("=" * 40)

        for stat_name in ["mean", "median"]:
            increase = _calculate_percentage_increase(
                getattr(singleplex_stats, stat_name),
                getattr(multiplex_stats, stat_name),
            )
            print(f"{stat_name.capitalize():6s} Read Length: {increase:6.2f}%")

        logger.info("Length analysis completed successfully")

    except Exception as e:
        logger.error(f"Error in length analysis: {str(e)}")
        raise


analyze_lengths(processed_aligned_nanoplot_df)
__main__ - INFO - Printing length statistics for Singleplexed samples
__main__ - INFO - Printing length statistics for Multiplexed samples
__main__ - INFO - Length analysis completed successfully
Singleplexed Samples Statistics:
========================================

Read Lengths:
  Max   : 836,896
  Min   : 40
  Mean  : 6,485
  Std   : 6,155
  Median: 5,369

Multiplexed Samples Statistics:
========================================

Read Lengths:
  Max   : 387,118
  Min   : 40
  Mean  : 7,086
  Std   : 7,668
  Median: 4,225

Percentage Increase (Singleplexed vs Multiplexed):
========================================
Mean   Read Length:  -8.48%
Median Read Length:  27.08%

3. Combined Plots¶

In [8]:
def create_combined_yield_plot(
    metrics_df: pl.DataFrame, figsize: Tuple[int, int] = (12, 10), dpi: int = 300
) -> plt.Figure:
    """
    Create a combined plot showing yields, read lengths, and length distribution.

    Args:
        metrics_df (pl.DataFrame): DataFrame containing metrics data
        figsize (Tuple[int, int], optional): Figure size. Defaults to (12, 12).
        dpi (int, optional): DPI for the figure. Defaults to 300.

    Returns:
        plt.Figure: Combined figure object

    Raises:
        ValueError: If required data is missing
    """
    try:
        # Create figure and GridSpec
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(3, 2, height_ratios=[0.6, 0.9, 1.2])

        # Plot yields (A and B)
        plot_sample_yields(metrics_df, gs=gs)

        # Plot read lengths (C and D)
        plot_read_lengths(
            metrics_df, gs=gridspec.GridSpecFromSubplotSpec(1, 2, gs[1, :])
        )

        # Plot read length distribution (E)
        plot_read_length_distribution(
            read_length_distribution,
            max_read_length,
            gs=gridspec.GridSpecFromSubplotSpec(1, 1, gs[2, :]),
        )

        # Add panel labels
        for i, label in enumerate(["A", "B", "C", "D", "E"]):
            ax = fig.axes[i]
            ax.text(
                -0.05,
                1.07,
                label,
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
                va="top",
            )

        # Remove redundant legends
        for ax in fig.axes[1:4]:
            ax.get_legend().remove()

        fig.set_constrained_layout(True)
        return fig

    except Exception as e:
        logger.error(f"Error creating combined yield plot: {str(e)}")
        raise


combined_yield_plot = create_combined_yield_plot(nanoplot_qc_metrics_df)
No description has been provided for this image

Read Quality¶

1. Basecalling Quality¶

In [9]:
def calculate_base_quality_distribution(
    processed_data: pl.DataFrame, basecall_type: str = "sup"
) -> pl.DataFrame:
    """
    Calculate distribution of base quality scores across samples.

    Args:
        processed_data (pl.DataFrame): Processed NanoPlot data
        basecall_type (str, optional): Basecall type to filter. Defaults to "sup"

    Returns:
        pl.DataFrame: Quality distribution statistics

    Raises:
        ValueError: If required columns are missing
    """
    try:
        filtered_data = processed_data.filter(pl.col("basecall") == basecall_type)

        min_qual = filtered_data.select(pl.col("quals").min()).item()
        max_qual = filtered_data.select(pl.col("quals").max()).item()
        bins = np.arange(min_qual, max_qual + 0.5, 0.5)

        quality_dist = (
            filtered_data.with_columns([pl.col("quals").cut(bins).alias("quals_bin")])
            .group_by(["anonymised_sample", "quals_bin", "multiplexing"])
            .agg(pl.len().alias("count"))
            .join(
                filtered_data.select(["anonymised_sample", "number_of_reads"]).unique(),
                on="anonymised_sample",
            )
        )

        # Calculate bin centers and percentages
        bin_categories = quality_dist.select(pl.col("quals_bin").unique()).to_series()
        bin_edges = [float(edge.split(",")[0][1:]) for edge in bin_categories]
        bin_centers = [
            (bin_edges[i] + bin_edges[i + 1]) / 2 for i in range(len(bin_edges) - 1)
        ]

        mapping_df = pl.DataFrame(
            {"quals_bin": bin_categories, "quals_bin_lower": bin_edges}
        )

        quality_dist = quality_dist.with_columns(
            [(pl.col("count") / pl.col("number_of_reads") * 100).alias("percentage")]
        ).join(
            mapping_df,
            on="quals_bin",
            how="left",
        )

        return quality_dist

    except Exception as e:
        logger.error(f"Error calculating quality distribution: {str(e)}")
        raise


def plot_base_quality_distribution(
    quality_dist: pl.DataFrame,
    figsize: Tuple[int, int] = (14, 6),
    dpi: int = 300,
    line_alpha: float = 0.8,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot the distribution of base quality scores across samples.

    Args:
        quality_dist (pl.DataFrame): DataFrame containing quality distribution data
        figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6).
        dpi (int, optional): Figure DPI. Defaults to 300.
        line_alpha (float, optional): Line transparency. Defaults to 0.8.
        gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.

    Returns:
        Optional[Figure]: Figure object if created independently.

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {
            "anonymised_sample",
            "quals_bin_lower",
            "percentage",
            "multiplexing",
        }
        if not all(col in quality_dist.columns for col in required_cols):
            missing = required_cols - set(quality_dist.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        plot_data = (
            quality_dist.with_columns(
                [
                    pl.col("anonymised_sample")
                    .str.extract(r"(\d+)")
                    .cast(pl.Int32)
                    .alias("sample_num")
                ]
            )
            .sort("sample_num")
            .rename(
                {
                    "multiplexing": r"$\mathbf{Multiplexing}$",
                    "anonymised_sample": r"$\mathbf{Sample}$",
                }
            )
        )

        sns.lineplot(
            data=plot_data,
            x="quals_bin_lower",
            y="percentage",
            hue=r"$\mathbf{Sample}$",
            style=r"$\mathbf{Multiplexing}$",
            alpha=line_alpha,
            ax=ax,
        )

        ax.legend(loc="upper right", bbox_to_anchor=(1.02, 1.1))
        ax.set_xlabel("Quality Score")
        ax.set_ylabel("Proportion of Reads (%)")
        ax.set_title("Distribution of Base Quality Scores")

        max_qual = int(plot_data["quals_bin_lower"].max())
        tick_positions = np.arange(0, max_qual + 1, 5)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_positions)

        if gs is None:
            plt.tight_layout()
            return fig
        else:
            ax.get_legend().remove()
            return None

    except Exception as e:
        logger.error(f"Error plotting quality distribution: {str(e)}")
        raise


base_quality_distribution_df = calculate_base_quality_distribution(
    processed_aligned_nanoplot_df
)

base_quality_dist_plot = plot_base_quality_distribution(base_quality_distribution_df)
No description has been provided for this image
In [10]:
def prepare_qscore_percentage_data(metrics_df: pl.DataFrame) -> pl.DataFrame:
    """
    Prepare QScore percentage data for visualization.

    Args:
        metrics_df (pl.DataFrame): DataFrame containing QC metrics

    Returns:
        pl.DataFrame: Processed QScore percentage data

    Raises:
        ValueError: If required columns are missing
    """
    try:
        qscore_columns = [
            "Reads_>Q5",
            "Reads_>Q7",
            "Reads_>Q10",
            "Reads_>Q12",
            "Reads_>Q15",
        ]

        # Verify required columns exist
        missing_cols = [col for col in qscore_columns if col not in metrics_df.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

        # Filter and unpivot DataFrame
        qscore_df = (
            metrics_df.filter(pl.col("basecall") == "sup")
            .select(
                [
                    "anonymised_sample",
                    "multiplexing",
                    *qscore_columns,
                ]
            )
            .unpivot(
                index=["anonymised_sample", "multiplexing"],
                on=qscore_columns,
                variable_name="Quality_Score",
                value_name="Percentage",
            )
            .with_columns(
                [
                    # Extract just the Qn part from the Quality_Score column
                    pl.col("Quality_Score")
                    .str.extract(r">Q(\d+)")
                    .map_elements(
                        lambda x: f"Q{x}" if x is not None else None,
                        return_dtype=pl.Utf8,
                    )
                    .alias("Quality_Score")
                ]
            )
        )

        return qscore_df

    except Exception as e:
        logger.error(f"Error preparing QScore percentage data: {str(e)}")
        raise


def plot_qscore_percentage(
    qscore_df: pl.DataFrame,
    figsize: Tuple[int, int] = (20, 6),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot QScore percentage distribution across samples.

    Args:
        qscore_df (pl.DataFrame): DataFrame containing QScore percentage data
        figsize (Tuple[int, int], optional): Figure size. Defaults to (20, 6).
        dpi (int, optional): Figure DPI. Defaults to 300.
        gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.

    Returns:
        Optional[Figure]: Figure object if created independently.

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {
            "anonymised_sample",
            "multiplexing",
            "Quality_Score",
            "Percentage",
        }
        if not all(col in qscore_df.columns for col in required_cols):
            missing = required_cols - set(qscore_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if gs is None:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax1 = fig.add_subplot(gs[0])
            ax2 = fig.add_subplot(gs[1])

        plot_data = qscore_df.with_columns(
            [
                pl.col("anonymised_sample")
                .str.extract(r"(\d+)")
                .cast(pl.Int32)
                .alias("sample_num"),
                pl.col("anonymised_sample").alias(r"$\mathbf{Sample}$"),
            ]
        ).sort("sample_num")

        quality_score_order = ["Q5", "Q7", "Q10", "Q12", "Q15"]

        for ax, multiplex_type in zip([ax1, ax2], ["multiplex", "singleplex"]):
            data = plot_data.filter(pl.col("multiplexing") == multiplex_type)

            sns.barplot(
                data=data,
                x=r"$\mathbf{Sample}$",
                y="Percentage",
                hue="Quality_Score",
                errorbar=None,
                ax=ax,
                hue_order=quality_score_order,
            )

            ax.set_xlabel("Sample")
            ax.set_ylabel("Proportion of Reads (%)")
            ax.set_title(
                f"Percentage of Reads Above Quality Scores\n{multiplex_type.capitalize()} Samples"
            )
            legend = (
                ax.legend(title="Quality Score", loc="lower right")
                .get_title()
                .set_fontweight("bold")
            )

            for tick in ax.get_xticklabels():
                tick.set_rotation(45)
                tick.set_ha("right")

            locs, labels = ax.get_xticks(), ax.get_xticklabels()
            ax.set_xticks([loc + 0.1 for loc in locs])

        if gs is None:
            plt.tight_layout()
            return fig
        else:
            ax1.get_legend().remove()
            ax2.legend(
                title="Quality Score", bbox_to_anchor=(1, 1.05), loc="upper left"
            ).get_title().set_fontweight("bold")
            return None

    except Exception as e:
        logger.error(f"Error plotting QScore percentage distribution: {str(e)}")
        raise


qscore_percentage_df = prepare_qscore_percentage_data(nanoplot_qc_metrics_df)
qscore_percentage_plot = plot_qscore_percentage(qscore_percentage_df)
No description has been provided for this image
In [11]:
with pl.Config(tbl_rows=len(qscore_percentage_df)):
    display(qscore_percentage_df)
shape: (70, 4)
anonymised_samplemultiplexingQuality_ScorePercentage
strstrstrf64
"Sample 1""multiplex""Q5"0.999
"Sample 2""multiplex""Q5"0.999
"Sample 3""multiplex""Q5"0.999
"Sample 4""multiplex""Q5"0.999
"Sample 5""multiplex""Q5"0.999
"Sample 6""multiplex""Q5"0.999
"Sample 7""singleplex""Q5"0.999
"Sample 8""singleplex""Q5"1.0
"Sample 9""singleplex""Q5"0.999
"Sample 10""singleplex""Q5"0.999
"Sample 11""singleplex""Q5"1.0
"Sample 12""singleplex""Q5"1.0
"Sample 13""singleplex""Q5"1.0
"Sample 14""singleplex""Q5"1.0
"Sample 1""multiplex""Q7"0.998
"Sample 2""multiplex""Q7"0.999
"Sample 3""multiplex""Q7"0.998
"Sample 4""multiplex""Q7"0.998
"Sample 5""multiplex""Q7"0.998
"Sample 6""multiplex""Q7"0.998
"Sample 7""singleplex""Q7"0.999
"Sample 8""singleplex""Q7"0.999
"Sample 9""singleplex""Q7"0.998
"Sample 10""singleplex""Q7"0.998
"Sample 11""singleplex""Q7"0.999
"Sample 12""singleplex""Q7"0.999
"Sample 13""singleplex""Q7"0.999
"Sample 14""singleplex""Q7"0.999
"Sample 1""multiplex""Q10"0.993
"Sample 2""multiplex""Q10"0.994
"Sample 3""multiplex""Q10"0.993
"Sample 4""multiplex""Q10"0.993
"Sample 5""multiplex""Q10"0.991
"Sample 6""multiplex""Q10"0.991
"Sample 7""singleplex""Q10"0.994
"Sample 8""singleplex""Q10"0.995
"Sample 9""singleplex""Q10"0.993
"Sample 10""singleplex""Q10"0.992
"Sample 11""singleplex""Q10"0.995
"Sample 12""singleplex""Q10"0.995
"Sample 13""singleplex""Q10"0.995
"Sample 14""singleplex""Q10"0.995
"Sample 1""multiplex""Q12"0.977
"Sample 2""multiplex""Q12"0.978
"Sample 3""multiplex""Q12"0.978
"Sample 4""multiplex""Q12"0.979
"Sample 5""multiplex""Q12"0.969
"Sample 6""multiplex""Q12"0.973
"Sample 7""singleplex""Q12"0.976
"Sample 8""singleplex""Q12"0.983
"Sample 9""singleplex""Q12"0.978
"Sample 10""singleplex""Q12"0.975
"Sample 11""singleplex""Q12"0.98
"Sample 12""singleplex""Q12"0.981
"Sample 13""singleplex""Q12"0.981
"Sample 14""singleplex""Q12"0.979
"Sample 1""multiplex""Q15"0.892
"Sample 2""multiplex""Q15"0.883
"Sample 3""multiplex""Q15"0.892
"Sample 4""multiplex""Q15"0.89
"Sample 5""multiplex""Q15"0.857
"Sample 6""multiplex""Q15"0.87
"Sample 7""singleplex""Q15"0.877
"Sample 8""singleplex""Q15"0.905
"Sample 9""singleplex""Q15"0.892
"Sample 10""singleplex""Q15"0.879
"Sample 11""singleplex""Q15"0.882
"Sample 12""singleplex""Q15"0.893
"Sample 13""singleplex""Q15"0.884
"Sample 14""singleplex""Q15"0.878
In [12]:
@dataclass
class QualityMetrics:
    """
    Data class for storing base quality metrics statistics.
    """

    max: float
    min: float
    mean: float
    std: float
    median: float


def _calculate_quality_stats(df: pl.DataFrame) -> QualityMetrics:
    """
    Calculate base quality statistics from a Polars DataFrame.

    Args:
        df (pl.DataFrame): Input DataFrame containing quality metrics with column 'quals'

    Returns:
        QualityMetrics: Object containing base quality statistics

    Raises:
        Exception: If there's an error calculating statistics from the DataFrame
        KeyError: If required column is missing from the DataFrame
    """
    try:
        return QualityMetrics(
            max=df.select(pl.col("quals").max()).item(),
            min=df.select(pl.col("quals").min()).item(),
            mean=df.select(pl.col("quals").mean()).item(),
            std=df.select(pl.col("quals").std()).item(),
            median=df.select(pl.col("quals").median()).item(),
        )
    except Exception as e:
        logger.error(f"Error calculating quality statistics: {str(e)}")
        raise


def _print_quality_stats(stats: QualityMetrics, sample_type: str) -> None:
    """
    Print formatted base quality statistics.

    Args:
        stats (QualityMetrics): QualityMetrics object containing statistics to print
        sample_type (str): Type of sample (Multiplexed/Singleplexed)

    Raises:
        Exception: If there's an error formatting or printing the statistics
    """
    try:
        logger.info(f"Printing quality statistics for {sample_type} samples")
        print(f"\n{sample_type} Samples Statistics:")
        print("=" * 40)
        print("\nBase Qualities:")
        for stat_name, value in vars(stats).items():
            print(f"  {stat_name.capitalize():6s}: {value:.2f}")
    except Exception as e:
        logger.error(f"Error printing quality statistics: {str(e)}")
        raise


def analyze_basecall_quality(df: pl.DataFrame) -> None:
    """
    Analyze and print base quality statistics for multiplexed and singleplexed samples.

    Args:
        df (pl.DataFrame): Input DataFrame containing quality metrics with columns:
            - multiplexing: str ('singleplex' or 'multiplex')
            - basecall: str ('sup' or other)
            - quals: int/float

    Raises:
        Exception: If there's an error during analysis
        ValueError: If required data is missing from the DataFrame
    """
    try:
        singleplex_quals = df.filter(
            (pl.col("multiplexing") == "singleplex") & (pl.col("basecall") == "sup")
        )
        multiplex_quals = df.filter(
            (pl.col("multiplexing") == "multiplex") & (pl.col("basecall") == "sup")
        )

        if singleplex_quals.height == 0 or multiplex_quals.height == 0:
            logger.warning("No data found for either singleplex or multiplex samples")
            return

        singleplex_stats = _calculate_quality_stats(singleplex_quals)
        multiplex_stats = _calculate_quality_stats(multiplex_quals)

        _print_quality_stats(singleplex_stats, "Singleplexed")
        _print_quality_stats(multiplex_stats, "Multiplexed")

        print("\nPercentage Increase (Singleplexed vs Multiplexed):")
        print("=" * 40)

        for stat_name in ["mean", "median"]:
            increase = _calculate_percentage_increase(
                getattr(singleplex_stats, stat_name),
                getattr(multiplex_stats, stat_name),
            )
            print(f"{stat_name.capitalize():6s} Base Quality: {increase:6.2f}%")

        logger.info("Quality analysis completed successfully")

    except Exception as e:
        logger.error(f"Error in quality analysis: {str(e)}")
        raise


analyze_basecall_quality(processed_aligned_nanoplot_df)
__main__ - INFO - Printing quality statistics for Singleplexed samples
__main__ - INFO - Printing quality statistics for Multiplexed samples
__main__ - INFO - Quality analysis completed successfully
Singleplexed Samples Statistics:
========================================

Base Qualities:
  Max   : 49.64
  Min   : 1.88
  Mean  : 20.64
  Std   : 4.70
  Median: 20.60

Multiplexed Samples Statistics:
========================================

Base Qualities:
  Max   : 49.65
  Min   : 1.98
  Mean  : 20.45
  Std   : 4.69
  Median: 20.34

Percentage Increase (Singleplexed vs Multiplexed):
========================================
Mean   Base Quality:   0.95%
Median Base Quality:   1.24%

2. Mapping Quality¶

In [13]:
def calculate_mapping_quality_distribution(
    processed_data: pl.DataFrame, basecall_type: str = "sup"
) -> pl.DataFrame:
    """
    Calculate distribution of mapping quality scores across samples.

    Args:
        processed_data (pl.DataFrame): Processed NanoPlot data
        basecall_type (str, optional): Basecall type to filter. Defaults to "sup"

    Returns:
        pl.DataFrame: Mapping quality distribution statistics

    Raises:
        ValueError: If required columns are missing
    """
    try:
        filtered_data = processed_data.filter(
            pl.col("basecall") == basecall_type
        ).with_columns(pl.col("mapQ"), pl.col("number_of_reads"))

        max_mapq = filtered_data.select(pl.col("mapQ").max()).item()
        bins = np.arange(0, max_mapq + 0.5, 0.5)

        mapping_dist = (
            filtered_data.with_columns([pl.col("mapQ").cut(bins).alias("mapQ_bin")])
            .group_by(["anonymised_sample", "mapQ_bin", "multiplexing"])
            .agg(pl.len().alias("count"))
            .join(
                filtered_data.group_by("anonymised_sample").agg(
                    pl.first("number_of_reads").alias("number_of_reads")
                ),
                on="anonymised_sample",
            )
        )

        # Extract lower boundary of each bin
        bin_categories = mapping_dist.select(pl.col("mapQ_bin").unique()).to_series()
        bin_edges = [float(edge.split(",")[0][1:]) for edge in bin_categories]

        mapping_df = pl.DataFrame(
            {
                "mapQ_bin": bin_categories,
                "mapQ_bin_lower": bin_edges,
            }
        )

        mapping_dist = mapping_dist.with_columns(
            [(pl.col("count") / pl.col("number_of_reads") * 100).alias("percentage")]
        ).join(mapping_df, on="mapQ_bin", how="left")

        return mapping_dist

    except Exception as e:
        logger.error(f"Error calculating mapping quality distribution: {str(e)}")
        raise

    except Exception as e:
        logger.error(f"Error calculating mapping quality distribution: {str(e)}")
        raise


def plot_mapping_quality_distribution(
    mapping_dist: pl.DataFrame,
    figsize: Tuple[int, int] = (14, 6),
    dpi: int = 300,
    line_alpha: float = 0.8,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot the distribution of mapping quality scores across samples.

    Args:
        mapping_dist (pl.DataFrame): DataFrame containing mapping quality distribution data
        figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6).
        dpi (int, optional): Figure DPI. Defaults to 300.
        line_alpha (float, optional): Line transparency. Defaults to 0.8.
        gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure.

    Returns:
        Optional[Figure]: Figure object if created independently.

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {
            "anonymised_sample",
            "mapQ_bin_lower",
            "percentage",
            "multiplexing",
        }
        if not all(col in mapping_dist.columns for col in required_cols):
            missing = required_cols - set(mapping_dist.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        plot_data = (
            mapping_dist.with_columns(
                [
                    pl.col("anonymised_sample")
                    .str.extract(r"(\d+)")
                    .cast(pl.Int32)
                    .alias("sample_num")
                ]
            )
            .sort("sample_num")
            .rename(
                {
                    "multiplexing": r"$\mathbf{Multiplexing}$",
                    "anonymised_sample": r"$\mathbf{Sample}$",
                }
            )
        )

        sns.lineplot(
            data=plot_data,
            x="mapQ_bin_lower",
            y="percentage",
            hue=r"$\mathbf{Sample}$",
            style=r"$\mathbf{Multiplexing}$",
            alpha=line_alpha,
            ax=ax,
        )

        ax.legend(loc="upper right", bbox_to_anchor=(1.02, 1.1))
        ax.set_xlabel("Mapping Quality Score")
        ax.set_ylabel("Proportion of Reads (%)")
        ax.set_title("Distribution of Mapping Quality Scores")

        max_mapq = int(plot_data["mapQ_bin_lower"].max())
        tick_positions = np.arange(0, max_mapq + 1, 5)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_positions)

        if gs is None:
            plt.tight_layout()
            return fig
        else:
            ax.legend(bbox_to_anchor=(1.05, 1.05), loc="upper left")
            return None

    except Exception as e:
        logger.error(f"Error plotting mapping quality distribution: {str(e)}")
        raise


mapping_quality_distribution_df = calculate_mapping_quality_distribution(
    processed_aligned_nanoplot_df
)

mapping_quality_dist_plot = plot_mapping_quality_distribution(
    mapping_quality_distribution_df
)
No description has been provided for this image
In [14]:
@dataclass
class MappingMetrics:
    """
    Data class for storing mapping quality metrics statistics.
    """

    max: float
    min: float
    mean: float
    std: float
    median: float


def _calculate_mapping_stats(df: pl.DataFrame) -> MappingMetrics:
    """
    Calculate mapping quality statistics from a Polars DataFrame.

    Args:
        df (pl.DataFrame): Input DataFrame containing mapping quality metrics with column 'mapQ'

    Returns:
        MappingMetrics: Object containing mapping quality statistics

    Raises:
        Exception: If there's an error calculating statistics from the DataFrame
        KeyError: If required column is missing from the DataFrame
    """
    try:
        return MappingMetrics(
            max=df.select(pl.col("mapQ").max()).item(),
            min=df.select(pl.col("mapQ").min()).item(),
            mean=df.select(pl.col("mapQ").mean()).item(),
            std=df.select(pl.col("mapQ").std()).item(),
            median=df.select(pl.col("mapQ").median()).item(),
        )
    except Exception as e:
        logger.error(f"Error calculating mapping statistics: {str(e)}")
        raise


def _print_mapping_stats(stats: MappingMetrics, sample_type: str) -> None:
    """
    Print formatted mapping quality statistics.

    Args:
        stats (MappingMetrics): MappingMetrics object containing statistics to print
        sample_type (str): Type of sample (Multiplexed/Singleplexed)

    Raises:
        Exception: If there's an error formatting or printing the statistics
    """
    try:
        logger.info(f"Printing mapping statistics for {sample_type} samples")
        print(f"\n{sample_type} Samples Statistics:")
        print("=" * 40)
        print("\nMapping Quality:")
        for stat_name, value in vars(stats).items():
            print(f"  {stat_name.capitalize():6s}: {value:.2f}")
    except Exception as e:
        logger.error(f"Error printing mapping statistics: {str(e)}")
        raise


def analyze_mapping_quality(df: pl.DataFrame) -> None:
    """
    Analyze and print mapping quality statistics for multiplexed and singleplexed samples.

    Args:
        df (pl.DataFrame): Input DataFrame containing mapping metrics with columns:
            - multiplexing: str ('singleplex' or 'multiplex')
            - basecall: str ('sup' or other)
            - mapQ: int/float

    Raises:
        Exception: If there's an error during analysis
        ValueError: If required data is missing from the DataFrame
    """
    try:
        singleplex_quals = df.filter(
            (pl.col("multiplexing") == "singleplex") & (pl.col("basecall") == "sup")
        )
        multiplex_quals = df.filter(
            (pl.col("multiplexing") == "multiplex") & (pl.col("basecall") == "sup")
        )

        if singleplex_quals.height == 0 or multiplex_quals.height == 0:
            logger.warning("No data found for either singleplex or multiplex samples")
            return

        singleplex_stats = _calculate_mapping_stats(singleplex_quals)
        multiplex_stats = _calculate_mapping_stats(multiplex_quals)

        _print_mapping_stats(singleplex_stats, "Singleplexed")
        _print_mapping_stats(multiplex_stats, "Multiplexed")

        print("\nPercentage Increase (Singleplexed vs Multiplexed):")
        print("=" * 40)

        for stat_name in ["mean", "median"]:
            increase = _calculate_percentage_increase(
                getattr(singleplex_stats, stat_name),
                getattr(multiplex_stats, stat_name),
            )
            print(f"{stat_name.capitalize():6s} Mapping Quality: {increase:6.2f}%")

        logger.info("Mapping analysis completed successfully")

    except Exception as e:
        logger.error(f"Error in mapping analysis: {str(e)}")
        raise


analyze_mapping_quality(processed_aligned_nanoplot_df)
__main__ - INFO - Printing mapping statistics for Singleplexed samples
__main__ - INFO - Printing mapping statistics for Multiplexed samples
__main__ - INFO - Mapping analysis completed successfully
Singleplexed Samples Statistics:
========================================

Mapping Quality:
  Max   : 60.00
  Min   : 0.00
  Mean  : 56.71
  Std   : 12.67
  Median: 60.00

Multiplexed Samples Statistics:
========================================

Mapping Quality:
  Max   : 60.00
  Min   : 0.00
  Mean  : 56.68
  Std   : 12.75
  Median: 60.00

Percentage Increase (Singleplexed vs Multiplexed):
========================================
Mean   Mapping Quality:   0.05%
Median Mapping Quality:   0.00%

3. Combined Plots¶

In [15]:
def create_combined_quality_plot(
    quality_distribution: pl.DataFrame,
    mapping_quality_distribution: pl.DataFrame,
    qscore_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 8),
    dpi: int = 300,
) -> plt.Figure:
    try:
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(2, 2)

        # Plot base quality distribution (A)
        plot_base_quality_distribution(
            quality_distribution,
            gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, 0]),
        )

        # Plot mapping quality distribution (B)
        plot_mapping_quality_distribution(
            mapping_quality_distribution,
            gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, 1]),
        )

        # Plot QScore percentages (C and D)
        plot_qscore_percentage(
            qscore_df, gs=gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[1, :])
        )

        # Add panel labels
        for i, label in enumerate(["A", "B", "C", "D"]):
            ax = fig.axes[i]
            ax.text(
                -0.1,
                1.05,
                label,
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
                va="top",
            )

        fig.set_constrained_layout(True)
        return fig

    except Exception as e:
        logger.error(f"Error creating combined quality plot: {str(e)}")
        raise


combined_quality_plot = create_combined_quality_plot(
    base_quality_distribution_df, mapping_quality_distribution_df, qscore_percentage_df
)
No description has been provided for this image

Sequencing Depth¶

1. Depth per Chromosome¶

In [16]:
def process_mosdepth_file(file_path: Path, suffix: str) -> pl.DataFrame:
    """Process a mosdepth summary file and return relevant depth statistics.

    Args:
        file_path: Path to the mosdepth summary file
        suffix: Suffix to remove from sample names

    Returns:
        DataFrame containing processed depth statistics with columns: chrom, mean, sample
    """
    # sample_name will have the suffix removed.
    sample_name = file_path.name.split(".")[0].replace(suffix, "")
    df = pl.read_csv(file_path, separator="\t")
    chromosomes = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
    df = df.filter(
        (pl.col("chrom").is_in(chromosomes + ["total"]))
        & ~pl.col("chrom").str.ends_with("_region")
    )
    df = df.select(["chrom", "mean"])
    df = df.with_columns(pl.lit(sample_name).alias("sample"))
    return df


def process_per_base_file(file_path: Path, suffix: str) -> pl.DataFrame:
    """Process a mosdepth per-base file and calculate statistics per chromosome.

    Args:
        file_path: Path to the per-base depth file
        suffix: Suffix to remove from sample names

    Returns:
        DataFrame with per-chromosome statistics including mean depth and standard error
    """
    # sample_name will have the suffix removed.
    sample_name = file_path.name.split(".")[0].replace(suffix, "")
    df = pl.read_csv(
        file_path,
        separator="\t",
        has_header=False,
        new_columns=["chrom", "start", "end", "depth"],
    )
    chromosomes = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
    return (
        df.filter(pl.col("chrom").is_in(chromosomes))
        .group_by("chrom")
        .agg(
            [
                pl.col("depth").mean().alias("mean"),
                (pl.col("depth").std() / pl.col("depth").count().sqrt()).alias("sem"),
            ]
        )
        .with_columns(pl.lit(sample_name).alias("sample"))
    )


def analyze_mosdepth_data(
    metrics_df: pl.DataFrame,
    summary_files: List[Path],
    per_base_files: List[Path],
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """Analyze mosdepth data from summary and per-base files.

    Args:
        metrics_df: DataFrame containing sample metrics and metadata
        summary_files: List of paths to mosdepth summary files
        per_base_files: List of paths to mosdepth per-base files

    Returns:
        Tuple containing:
            - DataFrame with per-chromosome depth statistics
            - DataFrame with total depth statistics

    Raises:
        FileNotFoundError: If no mosdepth files are found
    """
    try:
        if not summary_files or not per_base_files:
            raise FileNotFoundError("No mosdepth files found")

        logger.info(f"Processing {len(summary_files)} mosdepth summary files")
        all_dfs = [
            process_mosdepth_file(file, basecall_suffix) for file in summary_files
        ]
        depth_df = pl.concat(all_dfs)

        logger.info(f"Processing {len(per_base_files)} per-base files")
        all_per_base_dfs = [
            process_per_base_file(file, basecall_suffix) for file in per_base_files
        ]
        per_base_df = pl.concat(all_per_base_dfs)

        # Join per-base stats (using both "chrom" and "sample" so that the right rows merge)
        depth_df = depth_df.join(
            per_base_df.rename({"mean": "per_base_mean", "sem": "per_base_sem"}),
            on=["chrom", "sample"],
            how="left",
        )

        total_depth_df = (
            depth_df.filter(pl.col("chrom") == "total")
            .unique(subset="sample")
            .select(["sample", "mean"])
        )

        total_depth_df = (
            total_depth_df.rename({"mean": "mean_depth"})
            .join(
                metrics_df.select(["sample", "multiplexing", "anonymised_sample"]),
                on="sample",
            )
            .sort(["multiplexing", "sample"])
        )

        depth_df = depth_df.filter(pl.col("chrom") != "total")
        chromosome_order = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]
        depth_df = depth_df.with_columns(pl.col("chrom").cast(pl.Categorical))

        depth_df = depth_df.join(
            metrics_df.select(["sample", "multiplexing", "anonymised_sample"]),
            on="sample",
            how="left",
        )

        # Create sample_num by extracting the digit from the anonymised sample name
        depth_df = depth_df.with_columns(
            pl.col("anonymised_sample")
            .str.extract(r"(\d+)")
            .cast(pl.Int32)
            .alias("sample_num")
        ).sort("sample_num")

        wg_depth_df = (
            depth_df.filter(pl.col("chrom").is_in(chromosome_order))
            .unique(subset=["chrom", "anonymised_sample"])
            .sort(["anonymised_sample", "multiplexing"])
        )

        logger.info(
            f"Successfully processed depth data for {wg_depth_df.get_column('anonymised_sample').n_unique()} samples"
        )
        return wg_depth_df, total_depth_df

    except Exception as e:
        logger.error(f"Error analyzing mosdepth data: {str(e)}")
        raise


def plot_mean_depth_per_chromosome(
    wg_depth_df: pl.DataFrame,
    figsize: Tuple[int, int] = (14, 6),
    dpi: int = 300,
    line_alpha: float = 0.95,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot mean depth per chromosome with standard error of the mean.

    Args:
        wg_depth_df (pl.DataFrame): DataFrame containing whole-genome depth statistics
        figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6)
        dpi (int, optional): Figure DPI. Defaults to 300
        line_alpha (float, optional): Line transparency. Defaults to 0.8
        gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure

    Returns:
        Optional[plt.Figure]: Figure object if created independently

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {
            "anonymised_sample",
            "chrom",
            "mean",
            "per_base_sem",
            "multiplexing",
        }
        if not all(col in wg_depth_df.columns for col in required_cols):
            missing = required_cols - set(wg_depth_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        plot_df = (
            wg_depth_df.with_columns(
                pl.col("anonymised_sample")
                .str.extract(r"(\d+)")
                .cast(pl.Int32)
                .alias("sample_num")
            )
            .sort(["sample_num", "chrom"])
            .rename(
                {
                    "multiplexing": r"$\mathbf{Multiplexing}$",
                    "anonymised_sample": r"$\mathbf{Sample}$",
                }
            )
        )

        unique_samples = list(dict.fromkeys(plot_df[r"$\mathbf{Sample}$"].to_list()))
        color_palette = sns.color_palette("husl", n_colors=len(unique_samples))
        color_dict = dict(zip(unique_samples, color_palette))

        sns.lineplot(
            data=plot_df,
            x="chrom",
            y="mean",
            hue=r"$\mathbf{Sample}$",
            style=r"$\mathbf{Multiplexing}$",
            legend="full",
            palette=color_dict,
            hue_order=unique_samples,
            alpha=line_alpha,
            ax=ax,
        )

        for sample in unique_samples:
            sample_df = plot_df.filter(pl.col(r"$\mathbf{Sample}$") == sample)
            ax.fill_between(
                sample_df["chrom"],
                sample_df["mean"] - sample_df["per_base_sem"],
                sample_df["mean"] + sample_df["per_base_sem"],
                alpha=0.25,
                color=color_dict[sample],  # Match fill color to line color
            )

        ax.set_title("Mean Depth per Chromosome (with SEM)")
        ax.set_xlabel("Chromosome")
        ax.set_ylabel("Mean Depth")

        locs, labels = plt.xticks()
        ax.set_xticks([loc + 0.01 for loc in locs])
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.grid(axis="y", linestyle="--", alpha=0.7)

        if gs is None:
            ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
            plt.tight_layout()
            return fig
        else:
            ax.legend(bbox_to_anchor=(1.05, 1.05), loc="upper left")
            return None

    except Exception as e:
        logger.error(f"Error plotting mean depth per chromosome: {str(e)}")
        raise


basecall_suffix = "_sup"

mosdepth_summary_dir = Path("/scratch/prj/ppn_als_longread/ont-benchmark/qc/mosdepth/")

mosdepth_summary_files = list(
    mosdepth_summary_dir.glob(
        f"*{basecall_suffix}/*{basecall_suffix}.mosdepth.summary.txt"
    )
)

mosdepth_per_base_files = list(
    mosdepth_summary_dir.glob(f"*{basecall_suffix}/*{basecall_suffix}.per-base.bed.gz")
)

wg_depth_df, total_depth_df = analyze_mosdepth_data(
    metrics_df=nanoplot_qc_metrics_df,
    summary_files=mosdepth_summary_files,
    per_base_files=mosdepth_per_base_files,
)

mean_depth_chr_plot = plot_mean_depth_per_chromosome(wg_depth_df)
__main__ - INFO - Processing 14 mosdepth summary files
__main__ - INFO - Processing 14 per-base files
__main__ - INFO - Successfully processed depth data for 14 samples
No description has been provided for this image

2. Mean Whole Genome Depth¶

In [17]:
def plot_mean_whole_genome_depth(
    total_depth_df: pl.DataFrame,
    figsize: Tuple[int, int] = (14, 6),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """Plot mean whole genome depth per sample.

    Args:
        total_depth_df: DataFrame containing total depth statistics
        figsize: Figure size. Defaults to (14, 6)
        dpi: Figure DPI. Defaults to 300
        gs: GridSpec for plotting within a larger figure

    Returns:
        Optional[plt.Figure]: Figure object if created independently

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {"anonymised_sample", "mean_depth", "multiplexing"}
        if not all(col in total_depth_df.columns for col in required_cols):
            missing = required_cols - set(total_depth_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        plot_df = total_depth_df.with_columns(
            pl.col("anonymised_sample")
            .str.extract(r"(\d+)")
            .cast(pl.Int32)
            .alias("sample_num")
        ).sort("sample_num")

        sns.barplot(
            data=plot_df,
            x="anonymised_sample",
            y="mean_depth",
            hue="multiplexing",
            dodge=False,
            order=plot_df["anonymised_sample"].to_list(),
            ax=ax,
        )

        ax.set_title("Mean Whole Genome Depth per Sample")
        ax.set_xlabel("Sample")
        ax.set_ylabel("Depth")

        locs, labels = plt.xticks()
        ax.set_xticks([loc + 0.15 for loc in locs])
        ax.set_xticklabels(labels, rotation=45, ha="right")

        if gs is None:
            legend = ax.legend(
                bbox_to_anchor=(1, 1), loc="upper left", title="Multiplexing"
            )
            legend.get_title().set_weight("bold")
            plt.tight_layout()
            return fig
        else:
            legend = ax.legend(loc="upper left", title="Multiplexing")
            legend.get_title().set_weight("bold")
            return None

    except Exception as e:
        logger.error(f"Error plotting mean whole genome depth: {str(e)}")
        raise


mean_depth_wg_plot = plot_mean_whole_genome_depth(total_depth_df)
No description has been provided for this image
In [18]:
@dataclass
class DepthMetrics:
    """
    Data class for storing sequencing depth metrics statistics.
    """

    max: float
    min: float
    mean: float
    std: float
    median: float


def _calculate_depth_stats(df: pl.DataFrame, column_name: str) -> DepthMetrics:
    """
    Calculate depth statistics from a Polars DataFrame.

    Args:
        df (pl.DataFrame): Input DataFrame containing depth metrics
        column_name (str): Name of the column containing depth values

    Returns:
        DepthMetrics: Object containing depth statistics

    Raises:
        Exception: If there's an error calculating statistics from the DataFrame
        KeyError: If required column is missing from the DataFrame
    """
    try:
        return DepthMetrics(
            max=df.select(pl.col(column_name).max()).item(),
            min=df.select(pl.col(column_name).min()).item(),
            mean=df.select(pl.col(column_name).mean()).item(),
            std=df.select(pl.col(column_name).std()).item(),
            median=df.select(pl.col(column_name).median()).item(),
        )
    except Exception as e:
        logger.error(f"Error calculating depth statistics: {str(e)}")
        raise


def _print_depth_stats(
    stats: DepthMetrics, sample_type: str, depth_type: str = "per-chromosome"
) -> None:
    """
    Print formatted depth statistics.

    Args:
        stats (DepthMetrics): DepthMetrics object containing statistics to print
        sample_type (str): Type of sample (Multiplexed/Singleplexed)
        depth_type (str): Type of depth calculation ("per-chromosome" or "whole-genome")

    Raises:
        Exception: If there's an error formatting or printing the statistics
    """
    try:
        logger.info(f"Printing {depth_type} depth statistics for {sample_type} samples")
        print(f"\n{sample_type} Samples Statistics:")
        print("=" * 40)
        print("\nDepth:")
        for stat_name, value in vars(stats).items():
            formatted_value = f"{value:.2f}"
            print(f"  {stat_name.capitalize():6s}: {formatted_value}")
    except Exception as e:
        logger.error(f"Error printing depth statistics: {str(e)}")
        raise


def analyze_sequencing_depth(
    wg_depth_df: pl.DataFrame, total_depth_df: pl.DataFrame
) -> None:
    """
    Analyze and print sequencing depth statistics for multiplexed and singleplexed samples.

    Args:
        wg_depth_df (pl.DataFrame): DataFrame containing per-chromosome depth metrics
        total_depth_df (pl.DataFrame): DataFrame containing whole genome depth metrics

    Raises:
        Exception: If there's an error during analysis
        ValueError: If required data is missing from the DataFrame
    """
    try:
        # Per-chromosome depth analysis
        singleplexed_depth = wg_depth_df.filter(pl.col("multiplexing") == "singleplex")
        multiplexed_depth = wg_depth_df.filter(pl.col("multiplexing") == "multiplex")

        if singleplexed_depth.height == 0 or multiplexed_depth.height == 0:
            logger.warning("No data found for either singleplex or multiplex samples")
            return

        print("\nPer-Chromosome Depth Statistics:")
        singleplex_depth_stats = _calculate_depth_stats(singleplexed_depth, "mean")
        multiplex_depth_stats = _calculate_depth_stats(multiplexed_depth, "mean")

        _print_depth_stats(singleplex_depth_stats, "Singleplexed", "per-chromosome")
        _print_depth_stats(multiplex_depth_stats, "Multiplexed", "per-chromosome")

        # Whole genome depth analysis
        singleplexed_wg = total_depth_df.filter(pl.col("multiplexing") == "singleplex")
        multiplexed_wg = total_depth_df.filter(pl.col("multiplexing") == "multiplex")

        print("\nWhole Genome Depth Statistics:")
        singleplex_wg_stats = _calculate_depth_stats(singleplexed_wg, "mean_depth")
        multiplex_wg_stats = _calculate_depth_stats(multiplexed_wg, "mean_depth")

        _print_depth_stats(singleplex_wg_stats, "Singleplexed", "whole-genome")
        _print_depth_stats(multiplex_wg_stats, "Multiplexed", "whole-genome")

        print("\nPercentage Increase (Singleplexed vs Multiplexed):")
        print("=" * 40)
        for stat_name in ["mean", "median"]:
            wg_increase = _calculate_percentage_increase(
                getattr(singleplex_wg_stats, stat_name),
                getattr(multiplex_wg_stats, stat_name),
            )
            print(f"{stat_name.capitalize():6s} Depth: {wg_increase:6.2f}%")

    except Exception as e:
        logger.error(f"Error in depth analysis: {str(e)}")
        raise


analyze_sequencing_depth(wg_depth_df, total_depth_df)
__main__ - INFO - Printing per-chromosome depth statistics for Singleplexed samples
__main__ - INFO - Printing per-chromosome depth statistics for Multiplexed samples
__main__ - INFO - Printing whole-genome depth statistics for Singleplexed samples
__main__ - INFO - Printing whole-genome depth statistics for Multiplexed samples
Per-Chromosome Depth Statistics:

Singleplexed Samples Statistics:
========================================

Depth:
  Max   : 31.15
  Min   : 1.23
  Mean  : 19.01
  Std   : 5.98
  Median: 17.45

Multiplexed Samples Statistics:
========================================

Depth:
  Max   : 13.54
  Min   : 0.44
  Mean  : 9.45
  Std   : 2.33
  Median: 9.79

Whole Genome Depth Statistics:

Singleplexed Samples Statistics:
========================================

Depth:
  Max   : 27.80
  Min   : 14.42
  Mean  : 19.84
  Std   : 5.22
  Median: 17.20

Multiplexed Samples Statistics:
========================================

Depth:
  Max   : 12.11
  Min   : 7.86
  Mean  : 9.84
  Std   : 1.58
  Median: 9.93

Percentage Increase (Singleplexed vs Multiplexed):
========================================
Mean   Depth: 101.63%
Median Depth:  73.30%

3. Flowcell Quality¶

In [19]:
def read_flowcell_stats(file_path: Path) -> pl.DataFrame:
    """
    Read and process flowcell statistics from CSV file.

    Args:
        file_path: Path to the CSV file containing flowcell statistics

    Returns:
        pl.DataFrame: DataFrame containing processed flowcell statistics

    Raises:
        FileNotFoundError: If input file doesn't exist
        ValueError: If required columns are missing
    """
    try:
        if not file_path.exists():
            raise FileNotFoundError(f"Stats file not found: {file_path}")

        df = pl.read_csv(file_path)

        required_cols = {"flowcell_id", "number_pores_start"}
        if not all(col in df.columns for col in required_cols):
            missing = required_cols - set(df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        df = df.with_columns(
            [
                pl.col("flowcell_id")
                .cast(pl.Utf8)
                .str.contains("__")
                .map_elements(
                    lambda x: "multiplex" if x else "singleplex", return_dtype=pl.Utf8
                )
                .alias("multiplexing")
            ]
        )

        # Create anonymous flowcell names
        multiplex_df = (
            df.filter(pl.col("multiplexing") == "multiplex")
            .with_row_index("index")
            .with_columns(
                pl.col("index")
                .add(1)
                .map_elements(lambda x: f"Multiplex Flowcell {x}", return_dtype=pl.Utf8)
                .alias("new_flowcell_name")
            )
        )

        singleplex_df = (
            df.filter(pl.col("multiplexing") == "singleplex")
            .with_row_index("index")
            .with_columns(
                pl.col("index")
                .add(1)
                .map_elements(
                    lambda x: f"Singleplex Flowcell {x}", return_dtype=pl.Utf8
                )
                .alias("new_flowcell_name")
            )
        )

        return pl.concat([multiplex_df, singleplex_df]).sort(
            ["multiplexing", "flowcell_id"]
        )

    except Exception as e:
        logger.error(f"Error reading flowcell stats: {str(e)}")
        raise


def plot_flowcell_pores(
    df: pl.DataFrame,
    figsize: Tuple[int, int] = (14, 6),
    dpi: int = 300,
    marker_size: int = 8,
    line_alpha: float = 0.8,
    gs: Optional[gridspec.GridSpec] = None,
) -> plt.Figure:
    """
    Plot the number of pores available at start across flowcells.

    Args:
        df (pl.DataFrame): DataFrame containing flowcell statistics
        figsize (Tuple[int, int], optional): Figure size. Defaults to (14, 6).
        dpi (int, optional): Figure DPI. Defaults to 300.
        marker_size (int, optional): Size of markers. Defaults to 8.
        line_alpha (float, optional): Line transparency. Defaults to 0.8.

    Returns:
        plt.Figure: Matplotlib figure object

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {"new_flowcell_name", "number_pores_start", "multiplexing"}
        if not all(col in df.columns for col in required_cols):
            missing = required_cols - set(df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        sns.lineplot(
            data=df,
            x="new_flowcell_name",
            y="number_pores_start",
            hue="multiplexing",
            marker="o",
            style="multiplexing",
            alpha=line_alpha,
            markersize=marker_size,
            ax=ax,
        )

        ax.set_title("Number of Pores Available at Start of Sequencing")
        ax.set_xlabel("Flowcell ID")
        ax.set_ylabel("Number of Pores")
        plt.xticks(rotation=45, ha="right")
        ax.set_ylim(bottom=0)

        if gs is None:
            legend = ax.legend(
                bbox_to_anchor=(1, 1), loc="upper left", title="Multiplexing"
            )
            legend.get_title().set_weight("bold")
            plt.tight_layout()
            return fig
        else:
            legend = ax.legend(loc="lower right", title="Multiplexing")
            legend.get_title().set_weight("bold")
            return None

    except Exception as e:
        logger.error(f"Error plotting flowcell pores: {str(e)}")
        raise


seq_stats_path = Path("/scratch/prj/ppn_als_longread/ont-benchmark/seq_stats.csv")
flowcell_stats_df = read_flowcell_stats(seq_stats_path)
flowcell_plot = plot_flowcell_pores(flowcell_stats_df)
No description has been provided for this image
In [20]:
@dataclass
class PoreMetrics:
    """
    Data class for storing flowcell pore statistics.
    """

    max: float
    min: float
    mean: float
    std: float
    median: float


def _calculate_pore_stats(df: pl.DataFrame, multiplexing_type: str) -> PoreMetrics:
    """
    Calculate pore statistics from a Polars DataFrame.

    Args:
        df: Input DataFrame containing pore metrics
        multiplexing_type: Type of multiplexing to filter by

    Returns:
        PoreMetrics: Object containing pore statistics

    Raises:
        ValueError: If required columns are missing
        Exception: If there's an error calculating statistics
    """
    try:
        required_cols = {"multiplexing", "number_pores_start"}
        if not all(col in df.columns for col in required_cols):
            missing = required_cols - set(df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        subset = df.filter(pl.col("multiplexing") == multiplexing_type)

        return PoreMetrics(
            max=subset.select(pl.col("number_pores_start").max()).item(),
            min=subset.select(pl.col("number_pores_start").min()).item(),
            mean=subset.select(pl.col("number_pores_start").mean()).item(),
            std=subset.select(pl.col("number_pores_start").std()).item(),
            median=subset.select(pl.col("number_pores_start").median()).item(),
        )
    except Exception as e:
        logger.error(f"Error calculating pore statistics: {str(e)}")
        raise


def _print_pore_stats(stats: PoreMetrics, sample_type: str) -> None:
    """
    Print formatted pore statistics.

    Args:
        stats: PoreMetrics object containing statistics to print
        sample_type: Type of sample (Multiplexed/Singleplexed)

    Raises:
        Exception: If there's an error formatting or printing the statistics
    """
    try:
        logger.info(f"Printing pore statistics for {sample_type} flowcells")
        print(f"\n{sample_type} Flowcells Statistics:")
        print("=" * 40)
        print("\nNumber of Pores at Start:")
        for stat_name, value in vars(stats).items():
            formatted_value = f"{value:.2f}"
            print(f"  {stat_name.capitalize():6s}: {formatted_value}")
    except Exception as e:
        logger.error(f"Error printing pore statistics: {str(e)}")
        raise


def analyze_flowcell_pores(df: pl.DataFrame) -> None:
    """
    Analyze and print flowcell pore statistics for multiplexed and singleplexed samples.

    Args:
        df: DataFrame containing flowcell pore metrics

    Raises:
        Exception: If there's an error during analysis
        ValueError: If required data is missing from the DataFrame
    """
    try:
        singleplex_stats = _calculate_pore_stats(df, "singleplex")
        multiplex_stats = _calculate_pore_stats(df, "multiplex")

        _print_pore_stats(singleplex_stats, "Singleplexed")
        _print_pore_stats(multiplex_stats, "Multiplexed")

        print("\nPercentage Increase (Singleplexed vs Multiplexed):")
        print("=" * 40)
        for stat_name in ["mean", "median"]:
            increase = _calculate_percentage_increase(
                getattr(singleplex_stats, stat_name),
                getattr(multiplex_stats, stat_name),
            )
            print(f"{stat_name.capitalize():6s} Number of Pores: {increase:6.2f}%")

    except Exception as e:
        logger.error(f"Error in pore analysis: {str(e)}")
        raise


analyze_flowcell_pores(flowcell_stats_df)
__main__ - INFO - Printing pore statistics for Singleplexed flowcells
__main__ - INFO - Printing pore statistics for Multiplexed flowcells
Singleplexed Flowcells Statistics:
========================================

Number of Pores at Start:
  Max   : 8152.00
  Min   : 4874.00
  Mean  : 7008.75
  Std   : 1190.38
  Median: 7422.50

Multiplexed Flowcells Statistics:
========================================

Number of Pores at Start:
  Max   : 8223.00
  Min   : 8024.00
  Mean  : 8116.00
  Std   : 100.34
  Median: 8101.00

Percentage Increase (Singleplexed vs Multiplexed):
========================================
Mean   Number of Pores: -13.64%
Median Number of Pores:  -8.38%

4. Relation between Flowcell Quality and Mean Whole Genome Depth¶

In [21]:
def parse_seq_stats_data(
    seq_stats_df: pl.DataFrame, depth_df: pl.DataFrame
) -> pl.DataFrame:
    """
    Parse sequencing statistics data and merge with depth information.

    Args:
        seq_stats_df: Polars DataFrame containing sequencing statistics
        depth_df: Polars DataFrame containing depth information

    Returns:
        pl.DataFrame: Merged DataFrame with correctly summed depths for multiplexed samples

    Raises:
        ValueError: If required columns are missing
    """
    required_seq_cols = {"flowcell_id", "multiplexing"}
    required_depth_cols = {"sample", "mean_depth"}

    if not all(col in seq_stats_df.columns for col in required_seq_cols):
        missing = required_seq_cols - set(seq_stats_df.columns)
        raise ValueError(f"Missing required columns in seq_stats_df: {missing}")

    if not all(col in depth_df.columns for col in required_depth_cols):
        missing = required_depth_cols - set(depth_df.columns)
        raise ValueError(f"Missing required columns in depth_df: {missing}")

    # Create a mapping DataFrame for multiplexed samples
    multiplexed_samples = (
        seq_stats_df.filter(pl.col("multiplexing") == "multiplex")
        .select("flowcell_id")
        .with_columns([pl.col("flowcell_id").str.split("__").alias("sample_ids")])
        .explode("sample_ids")
    )

    # Create mapping for singleplex samples
    singleplex_samples = (
        seq_stats_df.filter(pl.col("multiplexing") == "singleplex")
        .select("flowcell_id")
        .with_columns([pl.col("flowcell_id").alias("sample_ids")])
    )

    # Combine mappings
    all_samples = pl.concat([multiplexed_samples, singleplex_samples])

    # Join with depth information and aggregate
    merged_depths = (
        all_samples.join(depth_df, left_on="sample_ids", right_on="sample", how="left")
        .group_by("flowcell_id")
        .agg([pl.col("mean_depth").sum().alias("total_mean_depth")])
    )

    # Join back to original stats DataFrame
    return seq_stats_df.join(merged_depths, on="flowcell_id", how="left")


def plot_flowcell_depth_correlation(
    df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 8),
    dpi: int = 300,
    marker_size: int = 100,
    confidence_alpha: float = 0.2,
    gs: Optional[gridspec.GridSpec] = None,
) -> Tuple[plt.Figure, Tuple[float, float, float, float]]:
    """
    Plot correlation between number of pores and sequencing depth.

    Args:
        df: DataFrame containing flowcell statistics and depth data
        figsize: Figure size (width, height)
        dpi: Figure resolution
        marker_size: Size of scatter plot markers
        confidence_alpha: Transparency of confidence interval
        gs: Optional GridSpec for subplot placement

    Returns:
        Tuple containing:
            - plt.Figure: Matplotlib figure object
            - Tuple[float, float, float, float]: (slope, intercept, r_value, p_value)

    Raises:
        ValueError: If required columns are missing
    """
    required_cols = {"number_pores_start", "total_mean_depth", "multiplexing"}
    if not all(col in df.columns for col in required_cols):
        missing = required_cols - set(df.columns)
        raise ValueError(f"Missing required columns: {missing}")

    try:
        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        # Create scatter plot
        sns.scatterplot(
            data=df,
            x="number_pores_start",
            y="total_mean_depth",
            hue="multiplexing",
            s=marker_size,
            ax=ax,
        )

        # Calculate regression
        x = df["number_pores_start"]
        y = df["total_mean_depth"]
        slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)

        # Plot regression line and confidence interval
        x_line = np.linspace(x.min(), x.max(), 100)
        y_line = slope * x_line + intercept

        ax.plot(
            x_line,
            y_line,
            color="gray",
            linestyle="-",
            linewidth=2,
            label="line of best fit",
        )

        # Calculate confidence interval
        n = len(x)
        y_pred = slope * x + intercept
        s_err = np.sqrt((y - y_pred).pow(2).sum() / (n - 2))
        t = stats.t.ppf(0.975, n - 2)
        ci = (
            t
            * s_err
            * np.sqrt(1 / n + (x_line - x.mean()) ** 2 / ((x - x.mean()) ** 2).sum())
        )

        ax.fill_between(
            x_line,
            y_line - ci,
            y_line + ci,
            color="gray",
            alpha=confidence_alpha,
            label="95% Confidence Interval",
        )

        ax.set_title(
            f"Number of Pores at Start vs Total Mean Whole Genome Depth\n"
            f"r = {r_value:.2f}, p = {p_value:.2g}"
        )
        ax.set_xlabel("Number of Pores at Start")
        ax.set_ylabel("Mean Whole Genome Depth")

        if gs is None:
            legend = ax.legend(
                bbox_to_anchor=(1, 1), loc="upper left", title="Multiplexing"
            )
            legend.get_title().set_weight("bold")
            plt.tight_layout()
            return fig, (slope, intercept, r_value, p_value)
        else:
            legend = ax.legend(loc="lower right", title="Multiplexing")
            legend.get_title().set_weight("bold")
            return None

    except Exception as e:
        logger.error(f"Error plotting flowcell depth correlation: {str(e)}")
        raise


merged_stats_df = parse_seq_stats_data(flowcell_stats_df, total_depth_df)
fig, pores_depth_regression_results = plot_flowcell_depth_correlation(merged_stats_df)

slope, intercept, r_value, p_value = pores_depth_regression_results

logger.info(
    f"Regression statistics:\n"
    f"Slope: {slope:.4f}\n"
    f"Intercept: {intercept:.4f}\n"
    f"R-value: {r_value:.4f}\n"
    f"P-value: {p_value:.4e}"
)
__main__ - INFO - Regression statistics:
Slope: 0.0025
Intercept: 1.6324
R-value: 0.6387
P-value: 3.4417e-02
No description has been provided for this image

5. Barcoding Quality¶

In [22]:
def _get_sample_barcode_mapping() -> Dict[str, str]:
    """
    Get mapping between sample IDs and barcodes.

    Returns:
        Dict[str, str]: Mapping of sample IDs to barcodes
    """
    return {
        "A046_12": "barcode01",
        "A079_07": "barcode02",
        "A081_91": "barcode03",
        "A048_09": "barcode04",
        "A097_92": "barcode05",
        "A085_00": "barcode06",
    }


def _parse_nanostats_barcoded(file_path: Path) -> Dict[str, int]:
    """
    Parse NanoStats barcoded file.

    Args:
        file_path (Path): Path to NanoStats barcoded file

    Returns:
        Dict[str, int]: Mapping of barcodes to read counts
    """
    metrics = {}
    with open(file_path, "r") as f:
        header = f.readline().strip().split("\t")
        values = f.readline().strip().split("\t")

        for barcode, value in zip(header[1:], values[1:]):
            if barcode == "unclassified" or barcode.startswith("barcode"):
                metrics[barcode] = int(value)
    return metrics


def _parse_flowcell_samples(seq_summaries_dir: Path) -> Dict[str, List[str]]:
    """
    Parse flowcell samples from directory names.

    Args:
        seq_summaries_dir (Path): Directory containing sequencing summaries

    Returns:
        Dict[str, List[str]]: Mapping of flowcell names to sample lists
    """
    flowcell_samples = {}
    for subdir in Path(seq_summaries_dir).iterdir():
        if "__" in subdir.name:
            samples = subdir.name.split("__")
            flowcell_samples[subdir.name] = samples
    return flowcell_samples


def plot_multiplexed_flowcell_reads(
    seq_summaries_dir: Path,
    figsize: Tuple[int, int] = (12, 6),
    dpi: int = 300,
    bar_width: float = 0.25,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot multiplexed flowcell reads distribution using Polars.

    Args:
        seq_summaries_dir (Path): Directory containing sequencing summaries
        figsize (Tuple[int, int]): Figure size
        dpi (int): Figure DPI
        bar_width (float): Width of bars in plot
        color_palette (str): Seaborn color palette name
        gs (gridspec.GridSpec, optional): GridSpec for plotting within a larger figure

    Returns:
        Optional[plt.Figure]: Figure object if created independently, None if using gridspec
    """
    try:
        sample_barcode_mapping = _get_sample_barcode_mapping()
        barcode_sample_mapping = {v: k for k, v in sample_barcode_mapping.items()}
        flowcell_samples = _parse_flowcell_samples(seq_summaries_dir)

        flowcell_rename = {
            name: f"Multiplex Flowcell {i+1}"
            for i, name in enumerate(flowcell_samples.keys())
        }

        unique_samples = sorted(set(sum(flowcell_samples.values(), [])))
        sample_rename = {
            sample: f"Sample {i+1 if i != 1 and i != 2 else 3 if i == 1 else 2}"
            for i, sample in enumerate(unique_samples)
        }
        sample_rename["Unclassified"] = "Unclassified"

        data = []
        for subdir in Path(seq_summaries_dir).iterdir():
            if "__" not in subdir.name:
                continue

            flowcell_name = subdir.name
            nanostats_path = subdir / "NanoStats_barcoded.txt"

            if not nanostats_path.exists():
                continue

            metrics = _parse_nanostats_barcoded(nanostats_path)

            for barcode, read_count in metrics.items():
                if barcode in barcode_sample_mapping:
                    sample = barcode_sample_mapping[barcode]
                    if sample in flowcell_samples[flowcell_name]:
                        data.append(
                            {
                                "Flowcell": flowcell_rename[flowcell_name],
                                "Sample": sample_rename[sample],
                                "Read Count": read_count,
                            }
                        )
                elif barcode == "unclassified":
                    data.append(
                        {
                            "Flowcell": flowcell_rename[flowcell_name],
                            "Sample": "Unclassified",
                            "Read Count": read_count,
                        }
                    )

        if not data:
            raise ValueError("No valid data found for plotting")

        df = pl.DataFrame(data)
        df = df.with_columns(
            [
                pl.col("Flowcell").cast(pl.Categorical),
                pl.col("Sample").cast(pl.Categorical),
            ]
        )

        df = df.sort(["Flowcell", "Sample"])

        flowcell_stats = (
            df.group_by("Flowcell")
            .agg(
                [
                    pl.col("Read Count").mean().alias("mean_reads"),
                    pl.col("Read Count").std().alias("std_reads"),
                ]
            )
            .with_columns(
                (pl.col("std_reads") / pl.col("mean_reads") * 100).alias(
                    "cv_percentage"
                )
            )
        )

        for row in flowcell_stats.iter_rows(named=True):
            print(
                f"{row['Flowcell']} - Coefficient of Variation: {row['cv_percentage']:.2f}%"
            )

        unclassified_stats = (
            df.group_by("Flowcell")
            .agg(
                pl.col("Read Count").sum().alias("total_reads"),
                pl.col("Read Count")
                .filter(pl.col("Sample") == "Unclassified")
                .alias("unclassified_reads"),
            )
            .with_columns(
                (pl.col("unclassified_reads") / pl.col("total_reads") * 100).alias(
                    "unclassified_percentage"
                )
            )
        )

        mean_unclassified = (
            unclassified_stats["unclassified_percentage"].explode().mean()
        )
        std_unclassified = unclassified_stats["unclassified_percentage"].explode().std()

        print(f"Unclassified reads: {mean_unclassified:.2f}% ± {std_unclassified:.2f}%")

        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        color = sns.color_palette()[0]
        flowcells = df.select(pl.col("Flowcell").unique()).to_series()
        x = range(len(flowcells) * 3)

        bars = ax.bar(x, df.select("Read Count").to_series(), bar_width, color=color)

        y_max = ax.get_ylim()[1]  # Upper limit of y-axis
        if y_max > 0:
            scale = int(np.floor(np.log10(y_max)))  # Compute order of magnitude
            if (
                scale >= 3
            ):  # Only apply if the scale is meaningful (e.g., thousands or more)
                ax.set_ylabel(f"Number of reads ($1×10^{{{scale}}}$)")
            else:
                ax.set_ylabel("Number of reads")
        else:
            ax.set_ylabel("Number of reads")

        ax.set_title("Number of barcoded reads")
        ax.tick_params(
            axis="x", which="both", bottom=False, top=False, labelbottom=False
        )
        ax.xaxis.grid(False)

        for i, bar in enumerate(bars):
            sample = df.row(i)[1]  # Get Sample value
            ax.text(
                bar.get_x() + bar.get_width() / 2 - 0.3,
                -1.8e6,
                sample,
                ha="center",
                va="bottom",
                rotation=45,
            )

        for i, flowcell in enumerate(flowcells):
            ax.text(i * 3 + 1, -2.4e6, flowcell, ha="center")

        for i in range(1, len(flowcells)):
            ax.axvline(x=i * 3 - 0.5, color="gray", linestyle="-", linewidth=0.5)

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error plotting multiplexed flowcell reads: {str(e)}")
        raise


multiplex_flowcell_reads_plot = plot_multiplexed_flowcell_reads(np_seq_summaries_dir)
Multiplex Flowcell 1 - Coefficient of Variation: 63.78%
Multiplex Flowcell 2 - Coefficient of Variation: 31.77%
Multiplex Flowcell 3 - Coefficient of Variation: 38.91%
Unclassified reads: 19.04% ± 3.15%
No description has been provided for this image

6. Combined Plots¶

In [23]:
def create_combined_sequencing_plot(
    wg_depth_df: pl.DataFrame,
    total_depth_df: pl.DataFrame,
    flowcell_stats_df: pl.DataFrame,
    merged_stats_df: pl.DataFrame,
    seq_summaries_dir: Path,
    figsize: Tuple[int, int] = (12, 16),
    dpi: int = 300,
) -> plt.Figure:
    try:
        # Create figure with GridSpec
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(4, 2, height_ratios=[1, 1, 1, 1])

        # Plot A: Mean Depth per Chromosome (full width)
        plot_mean_depth_per_chromosome(
            wg_depth_df,
            gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, :]),
        )

        # Plot B: Mean Whole Genome Depth per Sample
        plot_mean_whole_genome_depth(
            total_depth_df,
            gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[1, 0]),
        )

        # Plot C: Number of Pores Available
        plot_flowcell_pores(
            flowcell_stats_df,
            gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[1, 1]),
        )

        # Plot D: Correlation Plot (capture both figure and statistics)
        plot_flowcell_depth_correlation(
            merged_stats_df,
            gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[2, 0]),
        )

        # Plot E: Multiplexed Flowcell Reads
        plot_multiplexed_flowcell_reads(
            seq_summaries_dir,
            gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[2, 1]),
        )

        # Add panel labels
        for i, label in enumerate(["A", "B", "C", "D", "E"]):
            ax = fig.axes[i]
            ax.text(
                -0.1,
                1.05,
                label,
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
                va="top",
            )

        fig.set_constrained_layout(True)
        return fig

    except Exception as e:
        logger.error(f"Error creating combined sequencing plots: {str(e)}")
        raise


combined_sequencing_plot = create_combined_sequencing_plot(
    wg_depth_df=wg_depth_df,
    total_depth_df=total_depth_df,
    flowcell_stats_df=flowcell_stats_df,
    merged_stats_df=merged_stats_df,
    seq_summaries_dir=np_seq_summaries_dir,
)
Multiplex Flowcell 1 - Coefficient of Variation: 63.78%
Multiplex Flowcell 2 - Coefficient of Variation: 31.77%
Multiplex Flowcell 3 - Coefficient of Variation: 38.91%
Unclassified reads: 19.04% ± 3.15%
No description has been provided for this image

SNV Benchmark¶

Comparative analysis¶

1. Sensitivity, Precision, and F1¶

In [24]:
@dataclass
class SNVRTGMetrics:
    """Data class for RTG vcfeval metrics."""

    true_pos_baseline: int
    true_pos_call: int
    false_pos: int
    false_neg: int
    precision: float
    sensitivity: float
    f_measure: float


@dataclass
class SNVAnalysisConfig:
    """Configuration for RTG analysis."""

    technologies: tuple[str, ...] = ("ont", "illumina")
    complexities: tuple[str, ...] = ("hc", "lc")
    metrics_to_test: tuple[str, ...] = ("precision", "sensitivity", "f_measure")


@dataclass
class StatisticalResults:
    """Container for statistical test results."""

    t_statistic: float
    p_value: float
    adjusted_p_value: Optional[float] = None


def _read_rtg_summary(file_path: Path) -> Optional[Dict[str, float]]:
    """
    Read and parse RTG vcfeval summary file.

    Args:
        file_path: Path to the summary file

    Returns:
        Dictionary containing RTG metrics or None if parsing fails
    """
    try:
        with open(file_path, "r") as f:
            lines = f.readlines()

        none_line = next(
            (line for line in lines if line.strip().startswith("None")), None
        )

        if none_line:
            values = none_line.split()
            return SNVRTGMetrics(
                true_pos_baseline=int(values[1]),
                true_pos_call=int(values[2]),
                false_pos=int(values[3]),
                false_neg=int(values[4]),
                precision=float(values[5]),
                sensitivity=float(values[6]),
                f_measure=float(values[7]),
            ).__dict__

        logger.warning(f"No 'None' threshold line found in {file_path}")
        return None

    except FileNotFoundError:
        logger.error(f"File not found: {file_path}")
        return None
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {str(e)}")
        return None


def _calculate_rtg_statistics(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate statistics for RTG metrics grouped by complexity.

    Args:
        df: Input DataFrame containing RTG metrics

    Returns:
        DataFrame with calculated statistics
    """
    metrics = ["precision", "sensitivity", "f_measure"]
    stats_exprs = []

    for metric in metrics:
        stats_exprs.extend(
            [
                pl.col(metric).mean().alias(f"{metric}_mean"),
                pl.col(metric).std().alias(f"{metric}_std"),
                pl.col(metric).median().alias(f"{metric}_median"),
                pl.col(metric).min().alias(f"{metric}_min"),
                pl.col(metric).max().alias(f"{metric}_max"),
            ]
        )

    return df.group_by("complexity").agg(stats_exprs)


def collect_snv_rtg_metrics(
    sample_ids: pl.DataFrame, base_path: Path, config: SNVAnalysisConfig
) -> Dict[str, pl.DataFrame]:
    """
    Collect SNV metrics from RTG vcfeval summary files.

    Args:
        sample_ids: DataFrame containing sample IDs
        base_path: Base path to the summary files
        config: Analysis configuration

    Returns:
        Dictionary containing DataFrames with metrics for each technology
    """
    metrics_data = {tech: [] for tech in config.technologies}

    for row in sample_ids.iter_rows(named=True):
        for tech in config.technologies:
            for complexity in config.complexities:
                sample_id = row["ont_id"] if tech == "ont" else row["lp_id"]
                summary_file = (
                    base_path
                    / complexity
                    / "aggregate"
                    / f"{sample_id}.snv"
                    / "summary.txt"
                )

                if summary := _read_rtg_summary(summary_file):
                    metrics_entry = {
                        "sample_id": sample_id,
                        "complexity": complexity,
                        **summary,
                    }
                    metrics_data[tech].append(metrics_entry)
                else:
                    logger.warning(
                        f"Skipping empty summary for {sample_id}, {tech}, {complexity}"
                    )

    return {
        tech: pl.DataFrame(data)
        for tech, data in metrics_data.items()
        if data  # Only include non-empty data
    }


def _perform_ttest(
    ont_data: pl.DataFrame, illumina_data: pl.DataFrame, metric: str
) -> Tuple[float, float]:
    """
    Perform t-test between ONT and Illumina data for a given metric.

    Args:
        ont_data: DataFrame containing ONT metrics
        illumina_data: DataFrame containing Illumina metrics
        metric: Name of the metric to test

    Returns:
        Tuple containing t-statistic and p-value
    """
    ont_values = ont_data.get_column(metric).to_numpy()
    illumina_values = illumina_data.get_column(metric).to_numpy()
    return stats.ttest_ind(ont_values, illumina_values)


def run_rtg_statistical_analysis(
    rtg_metrics_dfs: Dict[str, pl.DataFrame], config: SNVAnalysisConfig
) -> Dict[str, pl.DataFrame]:
    """
    Run statistical analysis on RTG metrics.

    Args:
        rtg_metrics_dfs: Dictionary of DataFrames containing metrics for each technology
        config: Analysis configuration

    Returns:
        Dictionary containing statistical results for each technology
    """
    results: Dict[str, List[Dict]] = {
        complexity: [] for complexity in config.complexities
    }
    all_p_values = []

    for complexity in config.complexities:
        ont_data = rtg_metrics_dfs["ont"].filter(pl.col("complexity") == complexity)
        illumina_data = rtg_metrics_dfs["illumina"].filter(
            pl.col("complexity") == complexity
        )

        for metric in config.metrics_to_test:
            metric_lower = metric.lower()
            t_stat, p_value = _perform_ttest(ont_data, illumina_data, metric_lower)

            results[complexity].append(
                {"metric": metric, "t_statistic": t_stat, "p_value": p_value}
            )
            all_p_values.append(p_value)

    # FDR correction
    _, adjusted_p_values = multipletests(all_p_values, method="fdr_bh")[:2]

    p_value_idx = 0
    for complexity in results:
        for result in results[complexity]:
            result["adjusted_p_value"] = adjusted_p_values[p_value_idx]
            p_value_idx += 1

    # Convert results to DataFrames
    return {complexity: pl.DataFrame(data) for complexity, data in results.items()}


def display_combined_snv_rtg_statistics(
    rtg_metrics_dfs: Dict[str, pl.DataFrame],
    statistical_results: Dict[str, pl.DataFrame],
    config: SNVAnalysisConfig,
) -> Dict[str, Dict[str, pl.DataFrame]]:
    """
    Display combined statistics and statistical test results.

    Args:
        rtg_metrics_dfs: Dictionary of DataFrames containing metrics
        statistical_results: Dictionary of DataFrames containing statistical results
        config: Analysis configuration

    Returns:
        Dictionary containing ONT and Illumina statistics across complexities.
    """
    stats_data = {}

    print("\n### Sequencing Platform Statistics ###")

    for tech in config.technologies:
        print(f"\n--- {tech.upper()} ---")

        # Calculate and store statistics for both complexities together
        combined_stats = _calculate_rtg_statistics(rtg_metrics_dfs[tech])

        stats_data[tech] = combined_stats

        display(combined_stats)

        # Count total true positives, false negatives, and false positives per complexity
        variant_counts_df = (
            rtg_metrics_dfs[tech]
            .group_by("complexity")
            .agg(
                pl.col("true_pos_baseline").sum().alias("total_true_positives"),
                pl.col("false_neg").sum().alias("total_false_negatives"),
                pl.col("false_pos").sum().alias("total_false_positives"),
            )
            .with_columns(
                (
                    pl.col("total_true_positives")
                    + pl.col("total_false_negatives")
                    + pl.col("total_false_positives")
                ).alias("total_variants")
            )
        )

        print("Variant counts:")
        display(variant_counts_df)

    print("\n### Statistical Tests by Complexity ###")
    for complexity in config.complexities:
        print(f"\nResults for {complexity.upper()} regions:")
        display(statistical_results[complexity])

    return stats_data


sample_ids = pl.read_csv("sample_ids.csv")

snv_config = SNVAnalysisConfig()
base_path = Path("/scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval")

snv_rtg_metrics_dfs = collect_snv_rtg_metrics(sample_ids, base_path, snv_config)

snv_rtg_statistical_results = run_rtg_statistical_analysis(
    snv_rtg_metrics_dfs, snv_config
)

snv_rtg_statistics = display_combined_snv_rtg_statistics(
    snv_rtg_metrics_dfs, snv_rtg_statistical_results, snv_config
)

# Create combined DataFrames containing data for all complexities
snv_ont_stats = snv_rtg_statistics["ont"].with_columns(
    pl.when(pl.col("complexity").is_null())
    .then(pl.lit("unknown"))
    .otherwise(pl.col("complexity"))
    .alias("complexity")
)

snv_illumina_stats = snv_rtg_statistics["illumina"].with_columns(
    pl.when(pl.col("complexity").is_null())
    .then(pl.lit("unknown"))
    .otherwise(pl.col("complexity"))
    .alias("complexity")
)
### Sequencing Platform Statistics ###

--- ONT ---
shape: (2, 16)
complexityprecision_meanprecision_stdprecision_medianprecision_minprecision_maxsensitivity_meansensitivity_stdsensitivity_mediansensitivity_minsensitivity_maxf_measure_meanf_measure_stdf_measure_medianf_measure_minf_measure_max
strf64f64f64f64f64f64f64f64f64f64f64f64f64f64f64
"lc"0.7779570.0133310.780750.74860.79360.7342210.0186780.741150.69450.75470.7554290.0159670.761150.72050.7723
"hc"0.9530790.0098380.95510.92860.96250.9555640.0212290.96490.90850.97490.9542930.0154010.96050.91840.9686
Variant counts:
shape: (2, 5)
complexitytotal_true_positivestotal_false_negativestotal_false_positivestotal_variants
stri64i64i64i64
"lc"29850210803285131491665
"hc"87205424050664280019553609
--- ILLUMINA ---
shape: (2, 16)
complexityprecision_meanprecision_stdprecision_medianprecision_minprecision_maxsensitivity_meansensitivity_stdsensitivity_mediansensitivity_minsensitivity_maxf_measure_meanf_measure_stdf_measure_medianf_measure_minf_measure_max
strf64f64f64f64f64f64f64f64f64f64f64f64f64f64f64
"hc"0.9629430.0035520.964950.95450.96630.9724710.0008510.972850.97060.97360.9676930.0021140.96890.96250.9696
"lc"0.799450.0048080.80090.78820.80460.74330.0022830.74320.73880.74630.7703360.0030140.770650.76270.7738
Variant counts:
shape: (2, 5)
complexitytotal_true_positivestotal_false_negativestotal_false_positivestotal_variants
stri64i64i64i64
"lc"30217410436075798482332
"hc"88744802511283413529466960
### Statistical Tests by Complexity ###

Results for HC regions:
shape: (3, 4)
metrict_statisticp_valueadjusted_p_value
strf64f64f64
"precision"-3.5288530.0015760.004023
"sensitivity"-2.9775880.0062140.007457
"f_measure"-3.2252010.0033840.005076
Results for LC regions:
shape: (3, 4)
metrict_statisticp_valueadjusted_p_value
strf64f64f64
"precision"-5.6747130.0000060.000034
"sensitivity"-1.8052180.0826350.082635
"f_measure"-3.4327030.0020120.004023
In [25]:
def _get_significance_level(p_value: float) -> str:
    """
    Determine significance level based on p-value.

    Args:
        p_value: Statistical p-value

    Returns:
        str: Significance level indicator
    """
    if p_value < 0.001:
        return "***"
    elif p_value < 0.01:
        return "**"
    elif p_value < 0.05:
        return "*"
    return ""


def prepare_snv_performance_data(
    ont_stats: pl.DataFrame,
    illumina_stats: pl.DataFrame,
    stat_results: Dict[str, pl.DataFrame],
    metrics: Tuple[str, ...],
    complexities: Tuple[str, ...],
) -> pl.DataFrame:
    """
    Prepare data for performance visualization.

    Args:
        ont_stats: ONT statistics DataFrame
        illumina_stats: Illumina statistics DataFrame
        stat_results: Dictionary containing statistical test results
        metrics: Metrics to plot
        complexities: Complexity levels to plot

    Returns:
        pl.DataFrame: Prepared data for plotting
    """
    plot_data = []
    # Map plotting metric names to the underlying column name keys.
    metric_mapping = {
        "Precision": "precision",
        "Sensitivity": "sensitivity",
        "F-measure": "f_measure",
    }
    for complexity in complexities:
        for metric in metrics:
            try:
                col_name = f"{metric_mapping[metric]}_mean"

                # Check if we have data for this complexity
                ont_filtered = ont_stats.filter(pl.col("complexity") == complexity)
                illumina_filtered = illumina_stats.filter(
                    pl.col("complexity") == complexity
                )

                if ont_filtered.height == 0 or illumina_filtered.height == 0:
                    logger.warning(f"No data available for {complexity} {metric}")
                    continue

                ont_value = ont_filtered.get_column(col_name)[0]
                illumina_value = illumina_filtered.get_column(col_name)[0]

                # Get adjusted p-value from the statistical results
                stat_df = stat_results.get(complexity)
                if stat_df is not None:
                    stat_rows = stat_df.filter(
                        pl.col("metric") == metric_mapping[metric]
                    ).to_dicts()
                    adjusted_p_value = (
                        stat_rows[0]["adjusted_p_value"] if stat_rows else None
                    )
                else:
                    adjusted_p_value = None

                significance = (
                    _get_significance_level(adjusted_p_value)
                    if adjusted_p_value is not None
                    else ""
                )
                plot_data.extend(
                    [
                        {
                            "Complexity": complexity.upper(),
                            "Metric": metric,
                            "Technology": "long-read",
                            "Value": ont_value,
                            "Significance": significance,
                        },
                        {
                            "Complexity": complexity.upper(),
                            "Metric": metric,
                            "Technology": "short-read",
                            "Value": illumina_value,
                            "Significance": significance,
                        },
                    ]
                )
            except Exception as e:
                logger.error(
                    f"Error preparing data for {complexity} {metric}: {str(e)}"
                )
                continue

    return pl.DataFrame(plot_data)


def plot_snv_performance_metrics(
    plot_data_df: pl.DataFrame,
    figsize: Tuple[int, int] = (14, 6),
    dpi: int = 300,
    ylim: Tuple[float, float] = (0, 1),
    title: str = "Performance Comparison",
    metrics: Tuple[str, ...] = ("Precision", "Sensitivity", "F-measure"),
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[Tuple[plt.Figure, np.ndarray]]:
    """
    Create a performance comparison plot for SNV detection between long-read and short-read technologies.

    This function creates a side-by-side comparison of performance metrics for high and low complexity
    regions. It automatically generates two subplots (high and low complexity) with bar plots showing
    the specified performance metrics for both technologies.

    Args:
        plot_data_df (pl.DataFrame): Prepared data DataFrame from prepare_snv_performance_data
        figsize (Tuple[int, int], optional): Figure size in inches. Defaults to (14, 6)
        dpi (int, optional): Figure resolution. Defaults to 300
        ylim (Tuple[float, float], optional): Y-axis limits. Defaults to (0, 1.05)
        title (str, optional): Plot title. Defaults to "Performance Comparison"
        metrics (Tuple[str, ...], optional): Metrics to plot. Defaults to ("Precision", "Sensitivity", "F-measure")
        gs (Optional[gridspec.GridSpec], optional): GridSpec for subplot placement. Defaults to None

    Returns:
        Optional[Tuple[plt.Figure, np.ndarray]]: Figure and axes objects if gs is None, None otherwise

    Raises:
        Exception: If there's an error in creating the performance plots
    """
    try:
        if gs is None:
            fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            axes = [fig.add_subplot(gs[0, i]) for i in range(2)]

        complexities = ["HC", "LC"]
        for i, complexity in enumerate(complexities):
            complexity_data = plot_data_df.filter(pl.col("Complexity") == complexity)

            bars = sns.barplot(
                data=complexity_data,
                x="Metric",
                y="Value",
                hue="Technology",
                errorbar=None,
                ax=axes[i],
            )

            # Add value labels on top of each bar
            for p in bars.patches:
                value = p.get_height()
                if value > 0:  # Only annotate if value is not 0
                    axes[i].annotate(
                        f"{value:.3f}",
                        (p.get_x() + p.get_width() / 2.0, value),
                        ha="center",
                        va="bottom",
                        fontsize=8,
                        rotation=0,
                    )

            complexity_label = "High" if complexity == "HC" else "Low"
            axes[i].set_title(
                f"SNV Performance in {complexity_label} Complexity Regions", pad=15
            )
            axes[i].set_ylim(ylim)

            # Add significance annotations
            for metric_idx, metric in enumerate(metrics):
                metric_data = complexity_data.filter(pl.col("Metric") == metric)
                if metric_data.height > 0:
                    significance = metric_data.get_column("Significance")[0]
                    if significance:
                        y = metric_data.get_column("Value").max() + 0.02
                        axes[i].text(
                            metric_idx,
                            y,
                            significance,
                            ha="center",
                            va="bottom",
                            color="black",
                            fontweight="bold",
                        )

            axes[i].set_xlabel("")
            axes[i].set_ylabel("Performance")

            if i == 0:
                axes[i].legend_.remove()
            else:
                if gs is None:
                    legend = axes[i].legend(
                        title="Technology", bbox_to_anchor=(1, 1), loc="upper left"
                    )
                    legend.get_title().set_weight("bold")
                else:
                    legend = axes[i].legend(title="Technology", loc="lower right")
                    legend.get_title().set_weight("bold")

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error creating performance plots: {str(e)}")
        raise


performance_data = prepare_snv_performance_data(
    ont_stats=snv_ont_stats,
    illumina_stats=snv_illumina_stats,
    stat_results=snv_rtg_statistical_results,
    metrics=("Precision", "Sensitivity", "F-measure"),
    complexities=("hc", "lc"),
)

snv_performance_plot = plot_snv_performance_metrics(
    plot_data_df=performance_data,
)
No description has been provided for this image

2. Error Analysis¶

In [26]:
def _get_vcfeval_snv_error_paths(
    sample_id: str, tech: str, complexity: str, base_dir: Path = base_path
) -> Tuple[Path, Path, Path]:
    """
    Generate paths for VCF evaluation files.

    Args:
        sample_id: Sample identifier
        tech: Technology type (ont/illumina)
        complexity: Genomic complexity region (hc/lc)
        base_dir: Base directory for project data

    Returns:
        Tuple of Paths for (false positives, false negatives, query) VCF files
    """
    vcfeval_dir = base_dir / complexity / "aggregate" / f"{sample_id}.snv"
    return (
        vcfeval_dir / "fp.vcf.gz",
        vcfeval_dir / "fn.vcf.gz",
        vcfeval_dir / "tp.vcf.gz",
    )


def _count_snv_types(vcf_file: Path) -> Dict[str, int]:
    """
    Count different types of SNVs in a VCF file.

    Args:
        vcf_file: Path to VCF file

    Returns:
        Dictionary mapping SNV types to their counts
    """
    snv_counts: Dict[str, int] = {}
    try:
        with pysam.VariantFile(str(vcf_file)) as vcf:
            for record in vcf:
                ref = record.ref
                alt = record.alts[0]
                if len(ref) == 1 and len(alt) == 1:
                    snv_type = f"{ref}>{alt}"
                    snv_counts[snv_type] = snv_counts.get(snv_type, 0) + 1
    except Exception as e:
        logger.error(f"Error counting SNV types in {vcf_file}: {str(e)}")
    return snv_counts


def count_total_variants(vcf_file: Path) -> int:
    """
    Count total number of variants in a VCF file.

    Args:
        vcf_file: Path to VCF file

    Returns:
        Total number of variants
    """
    try:
        with pysam.VariantFile(str(vcf_file)) as vcf:
            return sum(1 for _ in vcf)
    except Exception as e:
        logger.error(f"Error counting variants in {vcf_file}: {str(e)}")
        return 0


def calculate_snv_error_rates(
    sample_ids: pl.DataFrame,
    technologies: List[str] = ["ont", "illumina"],
    complexities: List[str] = ["hc", "lc"],
    base_dir: Path = base_path,
) -> Tuple[
    Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]], Dict[str, Dict[str, int]]
]:
    """
    Calculate SNV error rates across samples and technologies.

    Args:
        sample_ids: DataFrame containing sample ID mappings
        technologies: List of sequencing technologies
        complexities: List of genomic complexity regions
        base_dir: Base directory for project data

    Returns:
        Tuple containing:
        - Nested dictionary of error rates
        - Dictionary of sample counts per technology and complexity
    """
    snv_types = [f"{ref}>{alt}" for ref in "ACGT" for alt in "ACGT" if ref != alt]
    snv_error_rates = {
        tech: {comp: {"FP": {}, "FN": {}} for comp in complexities}
        for tech in technologies
    }
    sample_counts = {tech: {comp: 0 for comp in complexities} for tech in technologies}

    for row in sample_ids.iter_rows(named=True):
        for tech in technologies:
            sample_id = row["ont_id"] if tech == "ont" else row["lp_id"]

            for complexity in complexities:
                fp_vcf, fn_vcf, query_vcf = _get_vcfeval_snv_error_paths(
                    sample_id, tech, complexity, base_dir
                )

                if not all(path.exists() for path in [fp_vcf, fn_vcf, query_vcf]):
                    logger.warning(
                        f"VCF files not found for {sample_id}, {tech}, {complexity}"
                    )
                    continue

                sample_counts[tech][complexity] += 1
                total_variants = count_total_variants(query_vcf)

                if total_variants == 0:
                    logger.warning(
                        f"No variants found for {sample_id}, {tech}, {complexity}"
                    )
                    continue

                for error_type, vcf_path in [("FP", fp_vcf), ("FN", fn_vcf)]:
                    counts = _count_snv_types(vcf_path)

                    for snv_type in snv_types:
                        if (
                            snv_type
                            not in snv_error_rates[tech][complexity][error_type]
                        ):
                            snv_error_rates[tech][complexity][error_type][snv_type] = []

                        error_rate = counts.get(snv_type, 0) / total_variants
                        snv_error_rates[tech][complexity][error_type][snv_type].append(
                            error_rate
                        )

    return snv_error_rates, sample_counts


def prepare_snv_error_data(
    snv_error_rates: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]
) -> pl.DataFrame:
    """
    Prepare SNV error rate data for visualization.

    Args:
        snv_error_rates: Nested dictionary containing error rates

    Returns:
        DataFrame containing processed error rate data
    """
    tech_mapping = {"ont": "long-read", "illumina": "short-read"}
    plot_data = []

    for tech, tech_data in snv_error_rates.items():
        display_tech = tech_mapping[tech]
        for complexity, comp_data in tech_data.items():
            for error_type, error_data in comp_data.items():
                for snv_type, rates in error_data.items():
                    plot_data.append(
                        {
                            "Technology": display_tech,
                            "Complexity": complexity.upper(),
                            "Error_Type": error_type,
                            "SNV_Type": snv_type,
                            "Error_Rate": np.mean(rates) if rates else 0.0,
                        }
                    )
    return pl.DataFrame(plot_data)


def perform_statistical_tests(
    snv_error_rates: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]],
    sample_counts: Dict[str, Dict[str, int]],
) -> pl.DataFrame:
    """
    Perform statistical tests comparing error rates between technologies.

    Args:
        snv_error_rates: Nested dictionary containing error rates
        sample_counts: Dictionary of sample counts

    Returns:
        DataFrame containing statistical test results
    """
    results = []
    p_values = []

    for complexity in ["hc", "lc"]:
        for error_type in ["FP", "FN"]:
            for snv_type in snv_error_rates["ont"][complexity][error_type]:
                long_read_rates = snv_error_rates["ont"][complexity][error_type][
                    snv_type
                ]
                short_read_rates = snv_error_rates["illumina"][complexity][error_type][
                    snv_type
                ]

                if not long_read_rates or not short_read_rates:
                    logger.warning(
                        f"No data for {complexity}, {error_type}, {snv_type}"
                    )
                    continue

                n = min(len(long_read_rates), len(short_read_rates))
                try:
                    t_stat, p_val = stats.ttest_rel(
                        long_read_rates[:n], short_read_rates[:n]
                    )
                except Exception as e:
                    logger.error(
                        f"Error in t-test for {complexity}, {error_type}, {snv_type}: {str(e)}"
                    )
                    continue

                results.append(
                    {
                        "Complexity": complexity.upper(),
                        "Error_Type": error_type,
                        "SNV_Type": snv_type,
                        "t_statistic": t_stat,
                        "p_value": p_val,
                        "n": n,
                    }
                )
                p_values.append(p_val)

    if not results:
        logger.warning("No statistical test results could be calculated")
        return pl.DataFrame()

    try:
        rejected, p_corrected, _, _ = multipletests(p_values, method="fdr_bh")
    except Exception as e:
        logger.error(f"Error in multiple testing correction: {str(e)}")
        return pl.DataFrame(results)

    for i, (result, p_adj, is_rejected) in enumerate(
        zip(results, p_corrected, rejected)
    ):
        result.update(
            {
                "p_value_adjusted": p_adj,
                "significance": _get_significance_level(p_adj),
                "rejected": is_rejected,
            }
        )

    return pl.DataFrame(results)


def plot_snv_error_rates(
    plot_data: pl.DataFrame,
    statistical_results: pl.DataFrame,
    figsize: Tuple[int, int] = (16, 12),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
    ylim: Optional[Tuple[float, float]] = None,
) -> Optional[Tuple[plt.Figure, np.ndarray]]:
    """
    Create plots showing SNV error rates across technologies and complexities.

    Args:
        plot_data: DataFrame containing error rate data
        statistical_results: DataFrame containing statistical test results
        figsize: Figure dimensions as (width, height)
        dpi: Figure resolution
        gs: Optional GridSpec for subplot placement
        ylim: Optional Y-axis limits as (min, max)

    Returns:
        - (fig, axes) if gs is None
        - None if using a provided GridSpec
    """
    if gs is None:
        fig, axes = plt.subplots(2, 2, figsize=figsize, dpi=dpi)
    else:
        fig = plt.gcf()
        axes = np.array(
            [[fig.add_subplot(gs[i, j]) for j in range(2)] for i in range(2)]
        )

    try:
        complexities = sorted(plot_data["Complexity"].unique().to_list())
        error_types = sorted(plot_data["Error_Type"].unique().to_list())

        for i, complexity in enumerate(complexities):
            for j, error_type in enumerate(error_types):
                subset = plot_data.filter(
                    (pl.col("Complexity") == complexity)
                    & (pl.col("Error_Type") == error_type)
                ).to_pandas()

                if subset.empty:
                    logger.warning(f"No data for {complexity}, {error_type}")
                    continue

                sns.barplot(
                    data=subset,
                    x="SNV_Type",
                    y="Error_Rate",
                    hue="Technology",
                    ax=axes[i, j],
                )

                axes[i, j].get_legend().remove()

                # Set title based on complexity
                title_complexity = "High" if "HC" in complexity else "Low"
                axes[i, j].set_title(
                    f"{error_type} Rates in {title_complexity} Complexity Regions "
                )
                axes[i, j].set_xlabel("SNV Type")
                axes[i, j].set_ylabel("Error Rate (%)")
                axes[i, j].yaxis.set_major_formatter(mticker.PercentFormatter(1))

                if ylim:
                    axes[i, j].set_ylim(ylim)

                # Add significance annotations
                for idx, snv_type in enumerate(subset["SNV_Type"].unique()):
                    filtered_results = statistical_results.filter(
                        (pl.col("Complexity") == complexity)
                        & (pl.col("Error_Type") == error_type)
                        & (pl.col("SNV_Type") == snv_type)
                    )

                    if not filtered_results.is_empty():
                        max_height = subset[subset["SNV_Type"] == snv_type][
                            "Error_Rate"
                        ].max()
                        significance = filtered_results.select(
                            pl.col("significance")
                        ).item()
                        axes[i, j].text(
                            idx,
                            max_height,
                            significance,
                            ha="center",
                            va="bottom",
                            fontweight="bold",
                        )

                if (i, j) == (0, 1) and gs is None:
                    axes[i, j].legend(
                        title="Technology", bbox_to_anchor=(1, 1), loc="upper left"
                    )
                    axes[i, j].get_legend().get_title().set_weight("bold")

    except Exception as e:
        logger.error(f"Error creating SNV error rate plots: {str(e)}")
        raise

    if gs is None:
        plt.tight_layout()
        return fig
    return None


snv_error_rates, sample_counts = calculate_snv_error_rates(
    sample_ids,
    technologies=snv_config.technologies,
    complexities=snv_config.complexities,
)

snv_error_plot_data = prepare_snv_error_data(snv_error_rates)

snv_error_statistical_results = perform_statistical_tests(
    snv_error_rates, sample_counts
)

snv_error_rate_plot = plot_snv_error_rates(
    snv_error_plot_data, snv_error_statistical_results
)

with pl.Config(tbl_rows=len(snv_error_statistical_results)):
    display(snv_error_statistical_results)
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008462-DNA_A09.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A048_09.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_F04.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A079_07.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_F02.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A081_91.snv/fp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_H09.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_A07.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_A07.snv/fn.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/lc/aggregate/A149_01.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008462-DNA_C03.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008462-DNA_D03.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008463-DNA_F01.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/LP6008462-DNA_C05.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A160_96.snv/tp.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/snv/rtg_vcfeval/hc/aggregate/A162_09.snv/tp.vcf.gz.tbi
shape: (48, 9)
ComplexityError_TypeSNV_Typet_statisticp_valuenp_value_adjustedsignificancerejected
strstrstrf64f64i64f64stri64
"HC""FP""A>C"4.7968010.000349140.00082"***"1
"HC""FP""A>G"5.049520.000223140.000708"***"1
"HC""FP""A>T"5.016350.000236140.000708"***"1
"HC""FP""C>A"4.6674720.00044140.000919"***"1
"HC""FP""C>G"5.4173160.000118140.00047"***"1
"HC""FP""C>T"4.8977770.000291140.000736"***"1
"HC""FP""G>A"4.781260.000359140.00082"***"1
"HC""FP""G>C"5.1694530.00018140.000619"***"1
"HC""FP""G>T"4.7539910.000377140.000822"***"1
"HC""FP""T>A"4.2961750.000869140.001739"**"1
"HC""FP""T>C"4.9479390.000266140.000711"***"1
"HC""FP""T>G"4.947740.000267140.000711"***"1
"HC""FN""A>C"2.9077510.012225140.016969"*"1
"HC""FN""A>G"2.9449840.011381140.016969"*"1
"HC""FN""A>T"2.7559240.016352140.020125"*"1
"HC""FN""C>A"2.9277630.011764140.016969"*"1
"HC""FN""C>G"3.3210260.00552140.009463"**"1
"HC""FN""C>T"3.084970.008695140.014391"*"1
"HC""FN""G>A"2.9014550.012373140.016969"*"1
"HC""FN""G>C"3.3241180.005487140.009463"**"1
"HC""FN""G>T"2.8114920.014703140.019604"*"1
"HC""FN""T>A"2.7724840.015842140.020125"*"1
"HC""FN""T>C"3.0037370.010166140.016265"*"1
"HC""FN""T>G"2.9660150.01093140.016924"*"1
"LC""FP""A>C"11.5345033.3539e-8140.000002"***"1
"LC""FP""A>G"6.4097370.000023140.000111"***"1
"LC""FP""A>T"8.2594620.000002140.000015"***"1
"LC""FP""C>A"8.8838316.9737e-7140.00001"***"1
"LC""FP""C>G"5.2418830.000159140.000587"***"1
"LC""FP""C>T"7.9085650.000003140.000017"***"1
"LC""FP""G>A"7.3262550.000006140.000031"***"1
"LC""FP""G>C"8.7786367.9751e-7140.00001"***"1
"LC""FP""G>T"7.9063640.000003140.000017"***"1
"LC""FP""T>A"9.2389334.4716e-7140.00001"***"1
"LC""FP""T>C"5.6871150.000075140.000325"***"1
"LC""FP""T>G"7.5023730.000004140.000027"***"1
"LC""FN""A>C"4.0567680.001359140.002609"**"1
"LC""FN""A>G"1.1835480.257783140.268991""0
"LC""FN""A>T"1.1582380.267603140.273297""0
"LC""FN""C>A"2.7618420.016168140.020125"*"1
"LC""FN""C>G"0.0659880.948392140.948392""0
"LC""FN""C>T"1.7498790.103689140.115746""0
"LC""FN""G>A"1.3437790.202005140.215472""0
"LC""FN""G>C"2.3649250.034259140.041111"*"1
"LC""FN""G>T"3.5744850.003394140.006265"**"1
"LC""FN""T>A"1.8267920.090773140.103741""0
"LC""FN""T>C"1.6694650.118915140.129725""0
"LC""FN""T>G"1.8578030.085987140.100668""0
No description has been provided for this image

3. Combined Plots¶

In [27]:
def create_combined_snv_metrics_plot(
    ont_stats: pl.DataFrame,
    illumina_stats: pl.DataFrame,
    stat_results: Dict[str, pl.DataFrame],
    error_plot_data: pl.DataFrame,
    error_statistical_results: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 10),
    dpi: int = 300,
) -> plt.Figure:
    """
    Create a combined figure showing SNV performance metrics and error rates.

    Args:
        ont_stats: ONT statistics DataFrame
        illumina_stats: Illumina statistics DataFrame
        stat_results: Dictionary containing statistical test results
        error_plot_data: DataFrame containing error rate data
        error_statistical_results: DataFrame containing error rate statistical results
        figsize: Figure dimensions as (width, height)
        dpi: Figure resolution

    Returns:
        Combined figure object containing all plots

    Raises:
        Exception: If there's an error creating the combined plot
    """
    try:
        # Create figure with GridSpec
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(3, 2, height_ratios=[1, 1, 1])

        # Prepare data for performance metrics plot
        performance_data = prepare_snv_performance_data(
            ont_stats=ont_stats,
            illumina_stats=illumina_stats,
            stat_results=stat_results,
            metrics=("Precision", "Sensitivity", "F-measure"),
            complexities=("hc", "lc"),
        )

        # Plot A & B: SNV Performance Metrics
        plot_snv_performance_metrics(
            plot_data_df=performance_data,
            gs=gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[0, :]),
        )

        # Plot C, D, E, F: Error Rates
        plot_snv_error_rates(
            plot_data=error_plot_data,
            statistical_results=error_statistical_results,
            gs=gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1:, :]),
        )

        # Add panel labels
        for i, ax in enumerate(fig.axes):
            label = chr(ord("A") + i)
            ax.text(
                -0.12,
                1.05,
                label,
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
                va="top",
            )

        fig.set_constrained_layout(True)
        return fig

    except Exception as e:
        logger.error(f"Error creating combined SNV metrics plot: {str(e)}")
        raise


combined_snv_metrics_plot = create_combined_snv_metrics_plot(
    ont_stats=snv_ont_stats,
    illumina_stats=snv_illumina_stats,
    stat_results=snv_rtg_statistical_results,
    error_plot_data=snv_error_plot_data,
    error_statistical_results=snv_error_statistical_results,
)
No description has been provided for this image

Indel Benchmark¶

Comparative analysis¶

1. Sensitivity, Precision, and F1¶

In [28]:
@dataclass
class IndelRTGMetrics:
    """Data class for indel RTG vcfeval metrics."""

    true_pos_baseline: int
    true_pos_call: int
    false_pos: int
    false_neg: int
    precision: float
    sensitivity: float
    f_measure: float


@dataclass
class IndelAnalysisConfig:
    """Configuration for indel analysis."""

    technologies: tuple[str, ...] = ("ont",)
    complexities: tuple[str, ...] = ("hc", "lc")
    metrics_to_test: tuple[str, ...] = ("precision", "sensitivity", "f_measure")
    base_dir: Path = Path("/scratch/prj/ppn_als_longread/ont-benchmark")
    subdir_mapping: dict[str, str] = field(
        default_factory=lambda: {"hc": "aggregate", "lc": "1bp"}
    )


def _get_vcfeval_indel_paths(
    sample_id: str, complexity: str, config: IndelAnalysisConfig
) -> tuple[Path, Path, Path]:
    """
    Generate paths for indel VCF evaluation files.

    Args:
        sample_id: Sample identifier
        complexity: Genomic complexity region
        config: Analysis configuration

    Returns:
        Tuple of Paths for (false positives, false negatives, true positives) VCF files
    """
    subdir = config.subdir_mapping.get(complexity, "")
    vcfeval_dir = (
        config.base_dir
        / "output"
        / "indel"
        / "rtg_vcfeval"
        / complexity
        / subdir
        / f"{sample_id}.indel"
    )
    return (
        vcfeval_dir / "fp.vcf.gz",
        vcfeval_dir / "fn.vcf.gz",
        vcfeval_dir / "tp.vcf.gz",
    )


def collect_indel_rtg_metrics(
    sample_ids: pl.DataFrame, config: IndelAnalysisConfig
) -> Dict[str, pl.DataFrame]:
    """
    Collect indel metrics from RTG vcfeval summary files.

    Args:
        sample_ids: DataFrame containing sample IDs
        config: Analysis configuration

    Returns:
        Dictionary containing DataFrames with metrics for each technology
    """
    metrics_data = {tech: [] for tech in config.technologies}

    for row in sample_ids.iter_rows(named=True):
        for tech in config.technologies:
            for complexity in config.complexities:
                sample_id = row["ont_id"]
                fp_vcf, fn_vcf, tp_vcf = _get_vcfeval_indel_paths(
                    sample_id, complexity, config
                )

                if not all(path.exists() for path in [fp_vcf, fn_vcf, tp_vcf]):
                    logger.warning(f"VCF files not found for {sample_id}, {complexity}")
                    continue

                if summary := _read_rtg_summary(fp_vcf.parent / "summary.txt"):
                    metrics_entry = {
                        "sample_id": sample_id,
                        "complexity": complexity,
                        **summary,
                    }
                    metrics_data[tech].append(metrics_entry)
                else:
                    logger.warning(
                        f"Skipping empty summary for {sample_id}, {tech}, {complexity}"
                    )

    return {tech: pl.DataFrame(data) for tech, data in metrics_data.items() if data}


def display_indel_statistics(
    rtg_metrics_dfs: Dict[str, pl.DataFrame],
    config: IndelAnalysisConfig,
) -> Dict[str, Dict[str, pl.DataFrame]]:
    """
    Process and compile statistics for indel analysis.

    Args:
        rtg_metrics_dfs: Dictionary of DataFrames containing metrics
        config: Analysis configuration

    Returns:
        Dictionary containing ONT statistics for each complexity
    """
    stats_data = {}

    for complexity in config.complexities:
        ont_stats = _calculate_rtg_statistics(
            rtg_metrics_dfs["ont"].filter(pl.col("complexity") == complexity)
        )

        stats_data[complexity] = {"ont": ont_stats}

    return stats_data


indel_config = IndelAnalysisConfig()

indel_rtg_metrics_dfs = collect_indel_rtg_metrics(sample_ids, indel_config)

indel_statistics = display_indel_statistics(indel_rtg_metrics_dfs, indel_config)

# Create combined DataFrame containing data for all complexities
indel_ont_stats = pl.concat(
    [
        indel_statistics[complexity]["ont"].with_columns(
            pl.lit(complexity).alias("complexity")
        )
        for complexity in indel_config.complexities
    ]
)

print("ONT Indel Statistics:")
display(indel_ont_stats)
ONT Indel Statistics:
shape: (2, 16)
complexityprecision_meanprecision_stdprecision_medianprecision_minprecision_maxsensitivity_meansensitivity_stdsensitivity_mediansensitivity_minsensitivity_maxf_measure_meanf_measure_stdf_measure_medianf_measure_minf_measure_max
strf64f64f64f64f64f64f64f64f64f64f64f64f64f64f64
"hc"0.7864860.0668360.80390.67390.86840.9053930.0350580.919650.83080.94020.8412860.0531170.857850.74760.9026
"lc"0.3945860.094140.38390.28790.55930.4615430.039510.469950.39360.51160.4230860.069250.42350.33380.5344
In [29]:
def prepare_indel_performance_data(
    ont_stats: pl.DataFrame,
    metrics: Tuple[str, ...],
    complexities: Tuple[str, ...],
) -> pl.DataFrame:
    """
    Prepare data for indel performance visualization.

    Args:
        ont_stats: ONT statistics DataFrame
        metrics: Metrics to plot
        complexities: Complexity levels to plot

    Returns:
        pl.DataFrame: Prepared data for plotting
    """
    plot_data = []
    # Map plotting metric names to the underlying column name keys
    metric_mapping = {
        "Precision": "precision",
        "Sensitivity": "sensitivity",
        "F-measure": "f_measure",
    }

    for complexity in complexities:
        for metric in metrics:
            try:
                col_name = f"{metric_mapping[metric]}_mean"

                # Check if we have data for this complexity
                ont_filtered = ont_stats.filter(pl.col("complexity") == complexity)

                if ont_filtered.height == 0:
                    logger.warning(f"No data available for {complexity} {metric}")
                    continue

                ont_value = ont_filtered.get_column(col_name)[0]

                plot_data.append(
                    {
                        "Complexity": complexity.upper(),
                        "Metric": metric,
                        "Technology": "long-read",
                        "Value": ont_value,
                        "Significance": "",  # No significance for indels as we only compare ONT
                    }
                )
            except Exception as e:
                logger.error(
                    f"Error preparing data for {complexity} {metric}: {str(e)}"
                )
                continue

    return pl.DataFrame(plot_data)


def plot_indel_performance_metrics(
    plot_data_df: pl.DataFrame,
    figsize: Tuple[int, int] = (14, 6),
    dpi: int = 300,
    ylim: Tuple[float, float] = (0, 1),
    title: str = "Indel Performance",
    metrics: Tuple[str, ...] = ("Precision", "Sensitivity", "F-measure"),
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Create a performance plot for indel detection.

    Args:
        plot_data_df: Prepared data DataFrame from prepare_indel_performance_data
        figsize: Figure size in inches
        dpi: Figure resolution
        ylim: Y-axis limits
        title: Plot title
        metrics: Metrics to plot
        gs: GridSpec for subplot placement

    Returns:
        Optional[plt.Figure]: Figure object if gs is None, None otherwise
    """
    try:
        if gs is None:
            fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            axes = [fig.add_subplot(gs[0, i]) for i in range(2)]

        complexities = ["HC", "LC"]
        for i, complexity in enumerate(complexities):
            complexity_data = plot_data_df.filter(pl.col("Complexity") == complexity)

            bars = sns.barplot(
                data=complexity_data,
                x="Metric",
                y="Value",
                errorbar=None,
                ax=axes[i],
            )

            # Add value labels on top of each bar
            for p in bars.patches:
                value = p.get_height()
                if value > 0:
                    axes[i].annotate(
                        f"{value:.3f}",
                        (p.get_x() + p.get_width() / 2.0, value),
                        ha="center",
                        va="bottom",
                        fontsize=8,
                        rotation=0,
                    )

            complexity_label = "High" if complexity == "HC" else "Low"
            axes[i].set_title(f"{complexity_label} Complexity", pad=15)
            axes[i].set_ylim(ylim)
            axes[i].set_xlabel("")
            axes[i].set_ylabel("Performance")

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error creating performance plots: {str(e)}")
        raise


indel_performance_data = prepare_indel_performance_data(
    ont_stats=indel_ont_stats,
    metrics=("Precision", "Sensitivity", "F-measure"),
    complexities=("hc", "lc"),
)

indel_performance_plot = plot_indel_performance_metrics(
    plot_data_df=indel_performance_data,
)
No description has been provided for this image

2. Error Analysis¶

In [30]:
@dataclass
class IndelSizeMetrics:
    """Data class for indel size-based metrics."""

    size: int
    complexity: str
    true_pos_baseline: int
    true_pos_call: int
    false_pos: int
    false_neg: int
    precision: float
    sensitivity: float
    f_measure: float


def get_vcfeval_indel_size_paths(
    sample_id: str, size: int, complexity: str, base_dir: Path
) -> Tuple[Path, Path, Path, Path, Path]:
    """
    Generate paths for size-specific indel VCF evaluation files.

    Args:
        sample_id: Sample identifier
        size: Size of indels in base pairs
        complexity: Genomic complexity region (hc/lc)
        base_dir: Base directory for project data

    Returns:
        Tuple containing paths for summary.txt, tp-baseline.vcf.gz, tp.vcf.gz,
        fp.vcf.gz, and fn.vcf.gz
    """
    base_path = (
        base_dir
        / "output"
        / "indel"
        / "rtg_vcfeval"
        / complexity
        / f"{size}bp"
        / f"{sample_id}.indel"
    )

    return (
        base_path / "summary.txt",
        base_path / "tp-baseline.vcf.gz",
        base_path / "tp.vcf.gz",
        base_path / "fp.vcf.gz",
        base_path / "fn.vcf.gz",
    )


def collect_indel_size_metrics(
    sample_ids: pl.DataFrame,
    complexities: List[str],
    base_dir: Path,
) -> Dict[int, Dict[str, List[IndelSizeMetrics]]]:
    """
    Collect indel metrics stratified by size from RTG vcfeval output.

    Args:
        sample_ids: DataFrame containing sample ID mappings
        complexities: List of genomic complexity regions
        base_dir: Base directory containing VCF files

    Returns:
        Nested dictionary of metrics organized by size and complexity
    """
    metrics: Dict[int, Dict[str, List[IndelSizeMetrics]]] = defaultdict(
        lambda: {comp: [] for comp in complexities}
    )

    for row in sample_ids.iter_rows(named=True):
        sample_id = row["ont_id"]
        for complexity in complexities:
            indel_dir = base_dir / "output" / "indel" / "rtg_vcfeval" / complexity
            if not indel_dir.exists():
                continue

            for size_dir in indel_dir.glob("*bp"):
                try:
                    size = int(size_dir.name.replace("bp", ""))
                except ValueError:
                    continue

                summary_path = size_dir / f"{sample_id}.indel" / "summary.txt"

                if not summary_path.is_file():
                    logger.warning(
                        f"Summary file not found for {sample_id}, size {size}, {complexity}"
                    )
                    continue

                try:
                    summary = _read_rtg_summary(summary_path)
                    if summary:
                        metrics[size][complexity].append(
                            IndelSizeMetrics(
                                size=size, complexity=complexity, **summary
                            )
                        )
                except Exception as e:
                    logger.error(
                        f"Error processing summary for {sample_id}, size {size}, "
                        f"complexity {complexity}: {str(e)}"
                    )

    return metrics


def process_indel_size_metrics(
    indel_metrics: Dict[int, Dict[str, List[IndelSizeMetrics]]]
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """
    Process indel metrics and calculate size-based statistics.

    Args:
        indel_metrics: Nested dictionary of metrics by size and complexity

    Returns:
        Tuple containing raw metrics DataFrame and aggregated statistics DataFrame
    """
    all_data = []
    for size, size_data in indel_metrics.items():
        for complexity, metrics_list in size_data.items():
            for metric in metrics_list:
                all_data.append(
                    {
                        "size": metric.size,
                        "complexity": metric.complexity.upper(),
                        "precision": metric.precision,
                        "sensitivity": metric.sensitivity,
                        "f_measure": metric.f_measure,
                        "true_pos_baseline": metric.true_pos_baseline,
                        "true_pos_call": metric.true_pos_call,
                        "false_pos": metric.false_pos,
                        "false_neg": metric.false_neg,
                    }
                )

    metrics_df = pl.DataFrame(all_data)

    stats_df = metrics_df.group_by(["size", "complexity"]).agg(
        [
            pl.col("precision").mean().alias("precision_mean"),
            pl.col("precision").std().alias("precision_std"),
            pl.col("sensitivity").mean().alias("sensitivity_mean"),
            pl.col("sensitivity").std().alias("sensitivity_std"),
            pl.col("f_measure").mean().alias("f_measure_mean"),
            pl.col("f_measure").std().alias("f_measure_std"),
        ]
    )

    return metrics_df, stats_df


def prepare_indel_size_performance_data(
    metrics_df: pl.DataFrame,
    metrics: List[str] = ["precision", "sensitivity", "f_measure"],
) -> pl.DataFrame:
    """
    Prepare indel performance data by size for visualization.

    Args:
        metrics_df: DataFrame containing raw metrics for each sample
        metrics: List of performance metrics to analyze

    Returns:
        DataFrame containing processed performance data
    """
    plot_data = []

    for row in metrics_df.iter_rows(named=True):
        for metric in metrics:
            plot_data.append(
                {
                    "size": row["size"],
                    "complexity": row["complexity"],
                    "metric": metric.capitalize(),
                    "value": row[metric],
                }
            )

    return pl.DataFrame(plot_data)


def stretched_exponential(x, a, b, c, β):
    """
    Stretched exponential function: a*exp(-(x/b)^β) + c
    Parameters:
        a: amplitude
        b: characteristic time/length scale
        c: vertical offset
        β: stretching exponent (controls decay rate variation)
    """
    return a * np.exp(-((x / b) ** β)) + c


def fit_and_get_ci(x, y, func=stretched_exponential, p0=None):
    """Fit curve and calculate 95% confidence intervals using median values"""
    x = np.array(x)
    y = np.array(y)

    sort_idx = np.argsort(x)
    x = x[sort_idx]
    y = y[sort_idx]

    p0 = p0 if p0 is not None else [1.0, 10.0, 0.0, 1.0]

    try:
        popt, pcov = curve_fit(
            func,
            x,
            y,
            p0=p0,
            maxfev=5000,
            bounds=([0, 0, -np.inf, 0], [np.inf, np.inf, np.inf, 10]),
        )

        perr = np.sqrt(np.diag(pcov))

        x_smooth = np.linspace(1, 50, 100)
        y_fit = func(x_smooth, *popt)

        y_err = np.zeros(len(x_smooth))
        for i in range(len(x_smooth)):
            jac = np.zeros(4)
            dx = 1e-6
            for j in range(4):
                params = list(popt)
                params[j] += dx
                jac[j] = (func(x_smooth[i], *params) - y_fit[i]) / dx
            y_err[i] = np.sqrt(np.sum((jac * perr) ** 2))

        return x_smooth, y_fit, y_err, popt, perr

    except Exception as e:
        logger.error(f"Error in curve fitting: {str(e)}")
        x_smooth = np.linspace(1, 50, 100)
        y_fit = np.zeros_like(x_smooth)
        y_err = np.zeros_like(x_smooth)
        popt = [0, 0, 0, 0]
        perr = [0, 0, 0, 0]

        return x_smooth, y_fit, y_err, popt, perr


def analyze_indel_size_performance(
    metrics_df: pl.DataFrame,
    indel_metrics: Dict[int, Dict[str, List[IndelSizeMetrics]]],
    metrics: List[str] = ["precision", "sensitivity", "f_measure"],
) -> Tuple[pl.DataFrame, Dict]:
    """
    Analyze indel performance by size categories and fit regression curves.
    """
    # Create size categories
    metrics_df = metrics_df.with_columns(
        pl.when(pl.col("size") <= 5)
        .then(pl.lit("1-5bp"))
        .when(pl.col("size") <= 10)
        .then(pl.lit("6-10bp"))
        .when(pl.col("size") <= 20)
        .then(pl.lit("11-20bp"))
        .when(pl.col("size") <= 50)
        .then(pl.lit("21-50bp"))
        .alias("size_category")
    )

    # Count variants by size category and complexity
    size_counts = {complexity: defaultdict(int) for complexity in ["hc", "lc"]}
    for size, size_data in indel_metrics.items():
        size_int = int(size)
        for complexity, metrics_list in size_data.items():
            for metric in metrics_list:
                if size_int <= 5:
                    size_cat = "1-5bp"
                elif size_int <= 10:
                    size_cat = "6-10bp"
                elif size_int <= 20:
                    size_cat = "11-20bp"
                elif size_int <= 50:
                    size_cat = "21-50bp"
                else:
                    continue

                size_counts[complexity.lower()][size_cat] += (
                    metric.true_pos_baseline + metric.false_pos + metric.false_neg
                )

    # Create summary statistics
    summary_stats = []
    for complexity in ["HC", "LC"]:
        for size_cat in ["1-5bp", "6-10bp", "11-20bp", "21-50bp"]:
            group = metrics_df.filter(
                (pl.col("complexity") == complexity)
                & (pl.col("size_category") == size_cat)
            )

            complexity_key = complexity.lower()
            total_count = sum(size_counts[complexity_key].values())
            category_count = size_counts[complexity_key].get(size_cat, 0)

            stats_dict = {
                "complexity": complexity,
                "size_category": size_cat,
                "count": category_count,
                "proportion": (
                    (category_count / total_count) * 100 if total_count > 0 else 0
                ),
            }

            for metric in metrics:
                metric_values = group.get_column(metric)
                if len(metric_values) > 0:
                    metric_median = np.median(metric_values)
                    q1 = np.percentile(metric_values, 25)
                    q3 = np.percentile(metric_values, 75)
                    iqr = q3 - q1
                else:
                    metric_median = float("nan")
                    iqr = float("nan")

                stats_dict[f"{metric}_median"] = metric_median
                stats_dict[f"{metric}_iqr"] = iqr

            summary_stats.append(stats_dict)

    summary_df = pl.DataFrame(summary_stats)

    # Round values
    summary_df = summary_df.with_columns(
        [
            pl.col("proportion").round(1).alias("proportion"),
            *[
                pl.col(f"{metric}_median").round(3).alias(f"{metric}_median")
                for metric in metrics
            ],
            *[
                pl.col(f"{metric}_iqr").round(3).alias(f"{metric}_iqr")
                for metric in metrics
            ],
        ]
    )

    # Fit regression curves
    regression_results = {}

    for metric in metrics:
        metric_cap = metric.capitalize()
        regression_results[metric_cap] = {}

        for complexity in ["HC", "LC"]:
            data = metrics_df.filter(pl.col("complexity") == complexity)

            # Calculate median by size
            median_by_size = data.group_by("size").agg(
                pl.col(metric).median().alias("median_value")
            )

            x = median_by_size.get_column("size").to_numpy()
            y = median_by_size.get_column("median_value").to_numpy()

            x_smooth, y_fit, y_err, popt, perr = fit_and_get_ci(
                x, y, stretched_exponential, p0=[1.0, 10.0, 0.0, 1.0]
            )

            regression_results[metric_cap][complexity] = {
                "x_smooth": x_smooth,
                "y_fit": y_fit,
                "y_err": y_err,
                "popt": popt,
                "perr": perr,
            }

            # Calculate R-squared and p-value
            if len(x) > 4:  # Need more data points than parameters
                residuals = y - stretched_exponential(x, *popt)
                ss_res = np.sum(residuals**2)
                ss_tot = np.sum((y - np.mean(y)) ** 2)
                r_squared = 1 - (ss_res / ss_tot)

                n = len(x)
                p = 4  # Number of parameters
                f_stat = (r_squared / p) / ((1 - r_squared) / (n - p - 1))
                p_value = 1 - stats.f.cdf(f_stat, p, n - p - 1)

                logger.info(f"{metric_cap} - {complexity} Region:")
                logger.info(f"R-squared: {r_squared:.3f} (p-value: {p_value:.3e})")

    return summary_df, regression_results


def plot_indel_size_performance(
    plot_data: pl.DataFrame,
    regression_results: Dict,
    figsize: Tuple[int, int] = (15, 10),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[Tuple[plt.Figure, np.ndarray]]:
    """
    Create performance plots for indel size analysis.

    Args:
        plot_data (pl.DataFrame): DataFrame containing prepared performance data.
        regression_results (Dict): Dictionary with regression fit data for each metric.
        figsize (Tuple[int, int], optional): Figure size in inches. Defaults to (15, 10).
        dpi (int, optional): Figure resolution. Defaults to 300.
        gs (Optional[gridspec.GridSpec], optional): GridSpec for subplot placement.
            If None, creates a standalone figure. Defaults to None.

    Returns:
        Optional[Tuple[plt.Figure, np.ndarray]]: If `gs` is None, returns the figure and axes array.
        If `gs` is provided, returns None (the plots are added as subfigures).
    """
    metrics = ["Precision", "Sensitivity", "F_measure"]

    if gs is None:
        fig, axes = plt.subplots(3, 1, figsize=figsize, dpi=dpi)
    else:
        fig = plt.gcf()
        axes = np.array([fig.add_subplot(gs[i, 0]) for i in range(3)])

    colors = sns.color_palette()

    for i, metric in enumerate(metrics):
        filtered = plot_data.filter(pl.col("metric") == metric).with_columns(
            pl.col("size").cast(pl.Int64)
        )
        sizes = filtered["size"].to_list()
        values = filtered["value"].to_list()
        complexities = filtered["complexity"].to_list()

        sns.boxplot(x=sizes, y=values, hue=complexities, ax=axes[i], width=0.8)

        complexity_labels = {"HC": "High Complexity", "LC": "Low Complexity"}

        for j, complexity in enumerate(["HC", "LC"]):
            if (
                metric in regression_results
                and complexity in regression_results[metric]
            ):
                results = regression_results[metric][complexity]

                axes[i].plot(
                    results["x_smooth"],
                    results["y_fit"],
                    "-",
                    color=colors[j],
                    label=f"{complexity_labels[complexity]} line of best fit",
                    alpha=0.6,
                    linewidth=2,
                )

        axes[i].set_xticks(np.arange(-0.5, 50.5, 1), minor=True)
        axes[i].grid(axis="x", linestyle="-", alpha=0.7, which="minor")
        axes[i].set_axisbelow(True)
        axes[i].set_title("F-measure" if metric == "F_measure" else metric)
        axes[i].set_xlabel("Indel Size (bp)")
        axes[i].set_ylabel("Performance")
        axes[i].set_xlim(-0.5, 49.5)
        axes[i].set_ylim(0, 1)

        if i == 0:
            handles, labels = axes[i].get_legend_handles_labels()
            labels = [
                "High Complexity" if l == "HC" else "Low Complexity" if l == "LC" else l
                for l in labels
            ]
            if gs is None:
                legend = axes[i].legend(
                    handles,
                    labels,
                    title="Region",
                    bbox_to_anchor=(1, 1),
                    loc="upper left",
                )
                legend.get_title().set_weight("bold")
            else:
                legend = axes[i].legend(
                    handles,
                    labels,
                    title="Region",
                    loc="lower left",
                )
                legend.get_title().set_weight("bold")
        else:
            legend = axes[i].get_legend()
            if legend is not None:
                legend.remove()

    if gs is None:
        plt.tight_layout()
        return fig
    return None

    return gs


indel_size_metrics = collect_indel_size_metrics(
    sample_ids, list(indel_config.complexities), indel_config.base_dir
)

indel_raw_metrics_df, indel_aggregated_stats_df = process_indel_size_metrics(
    indel_size_metrics
)

indel_size_plot_data = prepare_indel_size_performance_data(indel_raw_metrics_df)

indel_size_perf_summary_df, indel_size_perf_regression_results = (
    analyze_indel_size_performance(indel_raw_metrics_df, indel_size_metrics)
)

indel_size_perf_plot = plot_indel_size_performance(
    indel_size_plot_data, indel_size_perf_regression_results
)

with pl.Config(tbl_rows=len(indel_size_perf_summary_df)):
    display(indel_size_perf_summary_df)
__main__ - INFO - Precision - HC Region:
__main__ - INFO - R-squared: 0.849 (p-value: 1.110e-16)
__main__ - INFO - Precision - LC Region:
__main__ - INFO - R-squared: 0.690 (p-value: 5.951e-11)
__main__ - INFO - Sensitivity - HC Region:
__main__ - INFO - R-squared: 0.670 (p-value: 2.362e-10)
__main__ - INFO - Sensitivity - LC Region:
__main__ - INFO - R-squared: -0.000 (p-value: 1.000e+00)
__main__ - INFO - F_measure - HC Region:
__main__ - INFO - R-squared: 0.913 (p-value: 1.110e-16)
__main__ - INFO - F_measure - LC Region:
__main__ - INFO - R-squared: 0.575 (p-value: 6.025e-08)
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
shape: (8, 10)
complexitysize_categorycountproportionprecision_medianprecision_iqrsensitivity_mediansensitivity_iqrf_measure_medianf_measure_iqr
strstri64f64f64f64f64f64f64f64
"HC""1-5bp"331418889.50.8260.0710.9180.040.8690.066
"HC""6-10bp"2153295.80.7720.030.90.0670.8340.039
"HC""11-20bp"1165023.10.7980.0370.9320.0630.8620.044
"HC""21-50bp"565361.50.6240.1840.8850.0970.7240.156
"LC""1-5bp"1277495483.50.4450.1160.4830.1090.4540.124
"LC""6-10bp"13742279.00.4410.1730.590.1730.5010.181
"LC""11-20bp"7512504.90.4070.150.5770.1560.4780.16
"LC""21-50bp"3921602.60.250.0930.5710.0850.350.093
No description has been provided for this image

3. Size Distribution Analysis¶

In [31]:
def _classify_indel(ref: str, alt: str) -> Tuple[str, int]:
    """
    Classify an indel as insertion or deletion and determine its length.

    Args:
        ref: Reference allele
        alt: Alternate allele

    Returns:
        Tuple containing indel category (insertion/deletion) and length
    """
    indel_type = "insertion" if len(alt) > len(ref) else "deletion"
    indel_length = abs(len(alt) - len(ref))
    return indel_type, indel_length


def analyze_indel_size_distribution(
    sample_ids: pl.DataFrame,
    complexities: List[str],
) -> Dict[str, Dict[str, Dict[str, Dict[str, int]]]]:
    """
    Analyze indel size distribution using original VCF files.

    Args:
        sample_ids: DataFrame containing sample ID mappings.
        complexities: List of genomic complexity regions.

    Returns:
        Nested dictionary containing size distributions organized by:
        - First level: complexity region (hc/lc).
        - Second level: technology (ont/illumina).
        - Third level: indel type (insertion/deletion).
        - Fourth level: dictionary mapping sizes to counts.
    """
    distributions: Dict[str, Dict[str, Dict[str, DefaultDict[str, int]]]] = {
        comp: {
            tech: {
                indel_type: defaultdict(int) for indel_type in ["insertion", "deletion"]
            }
            for tech in ["ont", "illumina"]
        }
        for comp in complexities
    }

    for row in sample_ids.iter_rows(named=True):
        sample_id = row["ont_id"]

        for complexity in complexities:
            base_dir = (
                indel_config.base_dir / "output" / "indel" / "rtg_vcfeval" / complexity
            )

            # Extract indel sizes dynamically from available directories
            for indel_size_dir in base_dir.iterdir():
                if not indel_size_dir.is_dir() or not indel_size_dir.name.endswith(
                    "bp"
                ):
                    continue  # Skip non-indel size directories

                vcf_dir = indel_size_dir / f"{sample_id}.indel"

                query_vcf = vcf_dir / "query.vcf.gz"
                truth_vcf = vcf_dir / "truth.vcf.gz"

                for vcf_path, tech in [(query_vcf, "ont"), (truth_vcf, "illumina")]:
                    if not vcf_path.exists():
                        logger.warning(f"VCF file not found: {vcf_path}")
                        continue

                    try:
                        with pysam.VariantFile(str(vcf_path)) as vcf:
                            for record in vcf:
                                if len(record.ref) == 1 and len(record.alts[0]) == 1:
                                    continue

                                indel_type, length = _classify_indel(
                                    record.ref, record.alts[0]
                                )
                                distributions[complexity][tech][indel_type][length] += 1

                    except Exception as e:
                        logger.error(f"Error processing {vcf_path}: {str(e)}")

    return distributions


def prepare_indel_size_distribution_data(
    distributions: Dict[str, Dict[str, Dict[str, Dict[str, int]]]]
) -> pl.DataFrame:
    """
    Prepare indel size distribution data for visualization.

    Args:
        distributions: Nested dictionary containing size distributions organized by
                      complexity, technology, and indel type

    Returns:
        DataFrame containing processed distribution data with columns for Complexity,
        Technology, Indel Type, Size, and Percentage
    """
    plot_data = []
    tech_display = {"ont": "ONT", "illumina": "Illumina"}
    complexity_display = {"hc": "High Complexity", "lc": "Low Complexity"}

    for complexity, tech_data in distributions.items():
        for tech, type_data in tech_data.items():
            for indel_type, size_data in type_data.items():
                total = sum(size_data.values())
                if total == 0:
                    continue

                max_size = max(size_data.keys()) if size_data else 0
                for size in range(1, max_size + 1):
                    count = size_data.get(size, 0)
                    plot_data.append(
                        {
                            "Complexity": complexity_display[complexity],
                            "Technology": tech_display[tech],
                            "Indel Type": indel_type,
                            "Size": size,
                            "Percentage": (count / total) * 100 if total > 0 else 0,
                        }
                    )

    return pl.DataFrame(plot_data)


def plot_indel_size_distributions(
    df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 8),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Create plots showing indel size distributions across technologies and complexities.

    Args:
        df: DataFrame containing distribution data
        figsize: Figure size in inches
        dpi: Figure resolution
        gs: GridSpec for subplot placement. If None, creates a standalone figure.

    Returns:
        Optional[plt.Figure]: If gs is None, returns the figure.
        If gs is provided, returns None (plots are added as subfigures).
    """
    try:
        if gs is None:
            fig, axes = plt.subplots(2, 2, figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            axes = np.array(
                [[plt.subplot(gs[i, j]) for j in range(2)] for i in range(2)]
            )

        indel_types = ["insertion", "deletion"]
        complexities = ["High Complexity", "Low Complexity"]
        colors = sns.color_palette("colorblind")

        for i, indel_type in enumerate(indel_types):
            for j, complexity in enumerate(complexities):
                subset = df.filter(
                    (pl.col("Indel Type") == indel_type)
                    & (pl.col("Complexity") == complexity)
                )

                for k, tech in enumerate(["ONT", "Illumina"]):
                    tech_data = subset.filter(pl.col("Technology") == tech)

                    if tech_data.height > 0:
                        tech_label = "Long-read" if tech == "ONT" else "Short-read"

                        sns.lineplot(
                            data=tech_data,
                            x="Size",
                            y="Percentage",
                            label=tech_label,
                            ax=axes[i, j],
                            marker="o",
                            markersize=4,
                            color=colors[k],
                        )

                axes[i, j].set_title(f"{complexity} - {indel_type.capitalize()}s")
                axes[i, j].set_xlabel("Indel Size (bp)")
                axes[i, j].set_ylabel("Proportion of Indels (%)")

                # Check if we're in the top-right plot (0, 1) and handle legend
                if (i, j) == (0, 1):
                    if gs is None:
                        axes[i, j].legend(
                            title="Technology", bbox_to_anchor=(1, 1), loc="upper left"
                        )
                        axes[i, j].get_legend().get_title().set_weight("bold")
                    else:
                        axes[i, j].legend(title="Technology", loc="upper right")
                        axes[i, j].get_legend().get_title().set_weight("bold")
                else:
                    # Remove legend for all other subplots
                    legend = axes[i, j].get_legend()
                    if legend is not None:
                        legend.remove()

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error creating indel size distribution plots: {str(e)}")
        return None


def compare_size_distributions(
    distributions: Dict[str, Dict[str, Dict[str, Dict[str, int]]]]
) -> pl.DataFrame:
    """
    Perform statistical comparison of indel size distributions between ONT and Illumina
    using the Kolmogorov-Smirnov test.

    Args:
        distributions: Nested dictionary containing indel size distributions organized by
                      complexity, technology, and indel type

    Returns:
        DataFrame containing statistical test results with columns:
        - Complexity: Genomic complexity region (High/Low Complexity)
        - Indel Type: Type of indel (Insertion/Deletion)
        - KS Statistic: Kolmogorov-Smirnov test statistic
        - p-value: Raw p-value from KS test
        - Adjusted p-value: FDR-corrected p-value
    """
    results = []

    try:
        for complexity in distributions:
            for indel_type in ["insertion", "deletion"]:
                # Expand distribution to actual size arrays
                ont_sizes = []
                illumina_sizes = []

                for size, count in distributions[complexity]["ont"][indel_type].items():
                    ont_sizes.extend([size] * count)

                for size, count in distributions[complexity]["illumina"][
                    indel_type
                ].items():
                    illumina_sizes.extend([size] * count)

                if not ont_sizes or not illumina_sizes:
                    logger.warning(
                        f"Skipping KS test for {complexity}/{indel_type} due to empty data"
                    )
                    continue

                ks_stat, p_val = stats.ks_2samp(ont_sizes, illumina_sizes)

                results.append(
                    {
                        "Complexity": (
                            "High Complexity"
                            if complexity == "hc"
                            else "Low Complexity"
                        ),
                        "Indel Type": indel_type.capitalize(),
                        "KS Statistic": ks_stat,
                        "p-value": p_val,
                    }
                )

        if not results:
            logger.warning("No data available for statistical comparison")
            return pl.DataFrame(
                schema={
                    "Complexity": pl.Utf8,
                    "Indel Type": pl.Utf8,
                    "KS Statistic": pl.Float64,
                    "p-value": pl.Float64,
                    "Adjusted p-value": pl.Float64,
                }
            )

        results_df = pl.DataFrame(results)

        # Extract p-values and apply FDR correction
        p_values = results_df.get_column("p-value").to_numpy()
        _, adjusted_p, _, _ = multipletests(p_values, method="fdr_bh")

        # Add adjusted p-values to the DataFrame
        results_df = results_df.with_columns(pl.Series("Adjusted p-value", adjusted_p))

        return results_df

    except Exception as e:
        logger.error(f"Error performing statistical comparison: {str(e)}")
        return pl.DataFrame(
            schema={
                "Complexity": pl.Utf8,
                "Indel Type": pl.Utf8,
                "KS Statistic": pl.Float64,
                "p-value": pl.Float64,
                "Adjusted p-value": pl.Float64,
            }
        )


indel_distributions = analyze_indel_size_distribution(
    sample_ids, list(indel_config.complexities)
)

distribution_data = prepare_indel_size_distribution_data(indel_distributions)

indel_size_dist_plot = plot_indel_size_distributions(distribution_data)

statistical_results = compare_size_distributions(indel_distributions)

print("\nStatistical Test Results (Kolmogorov-Smirnov test with FDR correction):")
with pl.Config(tbl_rows=len(statistical_results)):
    display(statistical_results)
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/46bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/30bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/7bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/42bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/5bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/20bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/47bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A046_12.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A046_12.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/33bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/22bp/A048_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/19bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/13bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/18bp/A048_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/5bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/2bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A048_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/19bp/A048_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A048_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/34bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/44bp/A079_07.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/32bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A079_07.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/42bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A079_07.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A079_07.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/11bp/A081_91.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/8bp/A081_91.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/4bp/A081_91.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/40bp/A081_91.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A081_91.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A081_91.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/24bp/A085_00.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/24bp/A085_00.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/45bp/A085_00.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/25bp/A085_00.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/37bp/A085_00.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A085_00.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A085_00.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/17bp/A097_92.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/37bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/33bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/18bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/36bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/35bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/29bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/2bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A097_92.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/43bp/A097_92.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/39bp/A149_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/2bp/A149_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/26bp/A149_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A149_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/24bp/A149_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/26bp/A149_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/38bp/A153_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/27bp/A153_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/9bp/A153_01.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/40bp/A153_01.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/17bp/A153_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/6bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/6bp/A153_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/32bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/24bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/24bp/A153_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/34bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/2bp/A153_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/6bp/A153_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/9bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/18bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/45bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/45bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/22bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/16bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/23bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/10bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/10bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/25bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/14bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/43bp/A154_04.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A154_04.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/32bp/A154_06.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/34bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/27bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/49bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/30bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/23bp/A154_06.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A157_02.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/34bp/A157_02.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/11bp/A157_02.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/40bp/A157_02.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/19bp/A157_02.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/3bp/A160_96.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/3bp/A160_96.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/6bp/A160_96.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/33bp/A160_96.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/39bp/A160_96.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/45bp/A160_96.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/12bp/A160_96.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/27bp/A162_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/47bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/hc/1bp/A162_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/35bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/41bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/27bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/6bp/A162_09.indel/truth.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/23bp/A162_09.indel/query.vcf.gz.tbi
[W::hts_idx_load3] The index file is older than the data file: /scratch/prj/ppn_als_longread/ont-benchmark/output/indel/rtg_vcfeval/lc/1bp/A162_09.indel/truth.vcf.gz.tbi
Statistical Test Results (Kolmogorov-Smirnov test with FDR correction):
shape: (4, 5)
ComplexityIndel TypeKS Statisticp-valueAdjusted p-value
strstrf64f64f64
"High Complexity""Insertion"0.0127162.1192e-1162.1192e-116
"High Complexity""Deletion"0.0128522.3141e-1183.0854e-118
"Low Complexity""Insertion"0.0277550.00.0
"Low Complexity""Deletion"0.0116572.3832e-3184.7664e-318
No description has been provided for this image

4. Combined Plots¶

In [32]:
def create_combined_indel_metrics_plot(
    indel_ont_stats: pl.DataFrame,
    distribution_data: pl.DataFrame,
    indel_size_plot_data: pl.DataFrame,
    regression_results: Dict,
    metrics: Tuple[str, ...] = ("Precision", "Sensitivity", "F-measure"),
    complexities: Tuple[str, ...] = ("hc", "lc"),
    figsize: Tuple[int, int] = (12, 16),
    dpi: int = 300,
) -> plt.Figure:
    """
    Create a combined figure showing indel performance metrics, size distributions,
    and size-based performance analysis.

    Args:
        indel_ont_stats: DataFrame containing ONT indel statistics
        distribution_data: DataFrame containing indel size distribution data
        indel_size_plot_data: DataFrame containing size-based performance data
        regression_results: Dictionary containing regression analysis results
        metrics: Tuple of performance metrics to plot
        complexities: Tuple of complexity regions to analyze
        figsize: Figure dimensions (width, height)
        dpi: Figure resolution

    Returns:
        Combined figure object containing all plots

    Raises:
        ValueError: If input data is invalid or missing
        Exception: If there's an error creating the combined plot
    """
    try:
        # Input validation
        if indel_ont_stats.height == 0 or distribution_data.height == 0:
            raise ValueError("Input DataFrames cannot be empty")

        # Create figure with GridSpec
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(4, 2, height_ratios=[0.3, 0.3, 0.3, 1.3])

        # Section A & B: Performance Metrics
        performance_data = prepare_indel_performance_data(
            ont_stats=indel_ont_stats, metrics=metrics, complexities=complexities
        )
        gs_perf = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[0, :])
        plot_indel_performance_metrics(plot_data_df=performance_data, gs=gs_perf)

        # Add panel labels for A & B
        for i, ax in enumerate(fig.axes[:2]):
            label = chr(ord("A") + i)
            ax.text(
                -0.1,
                1.05,
                label,
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
            )

        # Section C-F: Size Distributions
        gs_dist = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1:3, :])
        plot_indel_size_distributions(df=distribution_data, gs=gs_dist)

        # Add panel labels for C-F
        for i, ax in enumerate(fig.axes[2:6]):
            label = chr(ord("C") + i)
            ax.text(
                -0.1,
                1.05,
                label,
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
            )

        # Section G-I: Size Performance Analysis
        gs_size = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[3, :])
        plot_indel_size_performance(
            plot_data=indel_size_plot_data,
            regression_results=regression_results,
            gs=gs_size,
        )

        # Add panel labels for G-I
        for i, ax in enumerate(fig.axes[6:]):
            label = chr(ord("G") + i)
            ax.text(
                -0.045,
                1.05,
                label,
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
            )

        # Adjust layout
        fig.set_constrained_layout(True)
        return fig

    except Exception as e:
        logger.error(f"Error creating combined indel metrics plot: {str(e)}")
        raise


combined_indel_fig = create_combined_indel_metrics_plot(
    indel_ont_stats=indel_ont_stats,
    distribution_data=distribution_data,
    indel_size_plot_data=indel_size_plot_data,
    regression_results=indel_size_perf_regression_results,
)
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
matplotlib.category - INFO - Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
No description has been provided for this image

Impacts of sequencing on SNV and Indel variant calling¶

1. Impact of multiplexing on variant calling¶

In [33]:
def create_snv_multiplexing_comparison_plot(
    metrics_df: pl.DataFrame,
    multiplexing_df: pl.DataFrame,
    config: SNVAnalysisConfig,
    figsize: Tuple[int, int] = (14, 8),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Create violin plots comparing SNV performance metrics between multiplexed and
    singleplexed samples.

    Args:
        metrics_df: DataFrame containing performance metrics
        multiplexing_df: DataFrame containing multiplexing information
        figsize: Figure size in inches
        dpi: Figure resolution
        metrics: Performance metrics to plot
        complexities: Complexity levels to compare
        gs: GridSpec for subplot placement. If None, creates standalone figure

    Returns:
        Optional[plt.Figure]: If gs is None, returns the figure.
        If gs is provided, returns None (plots are added as subfigures).

    Raises:
        ValueError: If required columns are missing from input DataFrames
    """
    try:
        if gs is None:
            fig, axes = plt.subplots(2, 3, figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            axes = np.array(
                [[plt.subplot(gs[i, j]) for j in range(3)] for i in range(2)]
            )

        fig.suptitle(
            "SNV Performance Metrics vs Multiplexing",
        )

        merged_df = metrics_df.join(
            multiplexing_df.select(["sample", "multiplexing"]),
            left_on="sample_id",
            right_on="sample",
            how="inner",
        )

        for row, complexity in enumerate(config.complexities):
            complexity_data = merged_df.filter(pl.col("complexity") == complexity)

            for col, metric in enumerate(config.metrics_to_test):
                ax = axes[row, col]

                sns.violinplot(
                    x="multiplexing",
                    y=metric,
                    data=complexity_data,
                    ax=ax,
                    hue="multiplexing",
                )

                ax.set_xlabel("Multiplexing")
                ax.set_ylabel(metric.capitalize().replace("_", "-"))

            y_mins = []
            y_maxs = []
            for ax in axes[row, :]:
                y_mins.append(ax.get_ylim()[0])
                y_maxs.append(ax.get_ylim()[1])

            y_min = min(y_mins)
            y_max = max(y_maxs)
            for ax in axes[row, :]:
                ax.set_ylim(y_min, y_max)

            axes[row, 0].annotate(
                "High Complexity" if row == 0 else "Low Complexity",
                xy=(-0.17, 0.5),
                xycoords="axes fraction",
                fontweight="bold",
                ha="right",
                va="center",
                rotation=90,
            )

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error creating performance comparison plots: {str(e)}")
        raise


performance_comparison_fig = create_snv_multiplexing_comparison_plot(
    metrics_df=snv_rtg_metrics_dfs["ont"],
    multiplexing_df=nanoplot_qc_metrics_df,
    config=snv_config,
)
No description has been provided for this image
In [34]:
def create_indel_multiplexing_comparison_plot(
    metrics_df: pl.DataFrame,
    multiplexing_df: pl.DataFrame,
    config: IndelAnalysisConfig,
    figsize: Tuple[int, int] = (14, 8),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Create violin plots comparing indel performance metrics between multiplexed and
    singleplexed samples.

    Args:
        metrics_df: DataFrame containing indel performance metrics
        multiplexing_df: DataFrame containing multiplexing information
        config: Analysis configuration
        figsize: Figure size in inches
        dpi: Figure resolution
        gs: GridSpec for subplot placement. If None, creates standalone figure

    Returns:
        Optional[plt.Figure]: If gs is None, returns the figure.
        If gs is provided, returns None (plots are added as subfigures).

    Raises:
        ValueError: If required columns are missing from input DataFrames
    """
    try:
        if gs is None:
            fig, axes = plt.subplots(2, 3, figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            axes = np.array(
                [[plt.subplot(gs[i, j]) for j in range(3)] for i in range(2)]
            )

        fig.suptitle(
            "Indel Performance Metrics vs Multiplexing",
        )

        merged_df = metrics_df.join(
            multiplexing_df.select(["sample", "multiplexing"]),
            left_on="sample_id",
            right_on="sample",
            how="inner",
        )

        for row, complexity in enumerate(config.complexities):
            complexity_data = merged_df.filter(pl.col("complexity") == complexity)

            for col, metric in enumerate(config.metrics_to_test):
                ax = axes[row, col]

                sns.violinplot(
                    x="multiplexing",
                    y=metric,
                    data=complexity_data,
                    ax=ax,
                    hue="multiplexing",
                )

                ax.set_xlabel("Multiplexing")
                ax.set_ylabel(metric.capitalize().replace("_", "-"))

            y_mins = []
            y_maxs = []
            for ax in axes[row, :]:
                y_mins.append(ax.get_ylim()[0])
                y_maxs.append(ax.get_ylim()[1])

            y_min = min(y_mins)
            y_max = max(y_maxs)
            for ax in axes[row, :]:
                ax.set_ylim(y_min, y_max)

            axes[row, 0].annotate(
                "High Complexity" if row == 0 else "Low Complexity",
                xy=(-0.17, 0.5),
                xycoords="axes fraction",
                fontweight="bold",
                ha="right",
                va="center",
                rotation=90,
            )

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error creating indel performance comparison plots: {str(e)}")
        raise


indel_performance_comparison_fig = create_indel_multiplexing_comparison_plot(
    metrics_df=indel_rtg_metrics_dfs["ont"],
    multiplexing_df=nanoplot_qc_metrics_df,
    config=indel_config,
)
No description has been provided for this image
In [35]:
def create_combined_multiplexing_comparison_plot(
    snv_metrics_df: pl.DataFrame,
    indel_metrics_df: pl.DataFrame,
    multiplexing_df: pl.DataFrame,
    snv_config: SNVAnalysisConfig,
    indel_config: IndelAnalysisConfig,
    figsize: Tuple[int, int] = (14, 16),
    dpi: int = 300,
) -> plt.Figure:
    """
    Create a combined figure showing both SNV and indel performance metrics
    comparisons between multiplexed and singleplexed samples.

    Args:
        snv_metrics_df: DataFrame containing SNV performance metrics
        indel_metrics_df: DataFrame containing indel performance metrics
        multiplexing_df: DataFrame containing multiplexing information
        snv_config: Configuration for SNV analysis
        indel_config: Configuration for indel analysis
        figsize: Figure dimensions (width, height)
        dpi: Figure resolution

    Returns:
        Combined figure object containing all plots

    Raises:
        ValueError: If input data is invalid or missing
        Exception: If there's an error creating the combined plot
    """
    try:
        # Input validation
        if snv_metrics_df.height == 0 or indel_metrics_df.height == 0:
            raise ValueError("Input DataFrames cannot be empty")

        # Create figure with GridSpec
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(
            3, 1, height_ratios=[1, 0.03, 1]
        )  # Add middle spacing row

        # SNV section
        gs_snv = gridspec.GridSpecFromSubplotSpec(2, 3, subplot_spec=gs[0])
        create_snv_multiplexing_comparison_plot(
            metrics_df=snv_metrics_df,
            multiplexing_df=multiplexing_df,
            config=snv_config,
            gs=gs_snv,
        )

        # Add SNV section title
        fig.text(
            0.5,
            0.99,
            "SNV Performance Metrics vs Multiplexing",
            ha="center",
            fontsize=12,
            fontweight="bold",
        )

        # Indel section
        gs_indel = gridspec.GridSpecFromSubplotSpec(
            2, 3, subplot_spec=gs[2]
        )  # Move Indel to row 3 for padding
        create_indel_multiplexing_comparison_plot(
            metrics_df=indel_metrics_df,
            multiplexing_df=multiplexing_df,
            config=indel_config,
            gs=gs_indel,
        )

        # Add Indel section title
        fig.text(
            0.5,
            0.48,
            "Indel Performance Metrics vs Multiplexing",
            ha="center",
            fontsize=12,
            fontweight="bold",
        )

        # Add panel labels (one per row)
        for i, ax in enumerate(
            fig.axes[::3]
        ):  # Step by 3 to label the first plot in each row
            label = chr(ord("A") + i)
            ax.text(
                -0.1,
                1.05,
                label,
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
            )

        # Adjust layout
        fig.suptitle("")
        fig.set_constrained_layout(True)
        return fig

    except Exception as e:
        logger.error(f"Error creating combined variant comparison plot: {str(e)}")
        raise


combined_variant_fig = create_combined_multiplexing_comparison_plot(
    snv_metrics_df=snv_rtg_metrics_dfs["ont"],
    indel_metrics_df=indel_rtg_metrics_dfs["ont"],
    multiplexing_df=nanoplot_qc_metrics_df,
    snv_config=snv_config,
    indel_config=indel_config,
)
No description has been provided for this image
In [36]:
@dataclass
class MetricStats:
    """Class for storing statistical metrics.

    Attributes:
        mean (float): The mean value of the metric
        std (float): The standard deviation of the metric
        median (float): The median value of the metric
    """

    mean: float
    std: float
    median: float


@dataclass
class VariantStats:
    """Class for storing variant statistics.

    Attributes:
        precision (MetricStats): Precision metrics including mean, std, and median
        sensitivity (MetricStats): Sensitivity metrics including mean, std, and median
        f_measure (MetricStats): F-measure metrics including mean, std, and median
        multiplexing (str): Type of multiplexing (singleplex or multiplex)
        variant_type (str): Type of variant (SNV or Indel)
        complexity (str): Complexity level (hc or lc)
    """

    precision: MetricStats
    sensitivity: MetricStats
    f_measure: MetricStats
    multiplexing: str
    variant_type: str
    complexity: str


def create_variant_multiplexing_stats(
    snv_df: pl.DataFrame, indel_df: pl.DataFrame, multiplexing_df: pl.DataFrame
) -> List[VariantStats]:
    """Creates variant multiplexing statistics by combining SNV, Indel, and multiplexing data.

    Args:
        snv_df (pl.DataFrame): DataFrame containing SNV metrics
        indel_df (pl.DataFrame): DataFrame containing Indel metrics
        multiplexing_df (pl.DataFrame): DataFrame containing multiplexing information

    Returns:
        List[VariantStats]: List of VariantStats objects containing calculated statistics
        for different combinations of variant types and complexity levels

    Raises:
        ValueError: If required columns are missing in the input DataFrames
    """

    def calculate_stats(
        df: pl.DataFrame, variant_type: str, complexity: str
    ) -> List[VariantStats]:
        """Calculates statistics for a specific variant type and complexity level.

        Args:
            df (pl.DataFrame): DataFrame containing variant metrics
            variant_type (str): Type of variant (SNV or Indel)
            complexity (str): Complexity level (hc or lc)

        Returns:
            List[VariantStats]: List of VariantStats objects for the specified variant type
            and complexity level
        """
        merged_df = df.filter(pl.col("complexity") == complexity).join(
            multiplexing_df.select(["sample", "multiplexing"]),
            left_on="sample_id",
            right_on="sample",
            how="inner",
        )

        stats_df = merged_df.group_by("multiplexing").agg(
            [
                pl.col("precision").mean().alias("Precision_mean"),
                pl.col("precision").std().alias("Precision_std"),
                pl.col("precision").median().alias("Precision_median"),
                pl.col("sensitivity").mean().alias("Sensitivity_mean"),
                pl.col("sensitivity").std().alias("Sensitivity_std"),
                pl.col("sensitivity").median().alias("Sensitivity_median"),
                pl.col("f_measure").mean().alias("F-measure_mean"),
                pl.col("f_measure").std().alias("F-measure_std"),
                pl.col("f_measure").median().alias("F-measure_median"),
            ]
        )

        return [
            VariantStats(
                precision=MetricStats(
                    mean=row["Precision_mean"],
                    std=row["Precision_std"],
                    median=row["Precision_median"],
                ),
                sensitivity=MetricStats(
                    mean=row["Sensitivity_mean"],
                    std=row["Sensitivity_std"],
                    median=row["Sensitivity_median"],
                ),
                f_measure=MetricStats(
                    mean=row["F-measure_mean"],
                    std=row["F-measure_std"],
                    median=row["F-measure_median"],
                ),
                multiplexing=row["multiplexing"],
                variant_type=variant_type,
                complexity=complexity,
            )
            for row in stats_df.to_dicts()
        ]

    stats_list = []
    for variant_type, df in [("SNV", snv_df), ("Indel", indel_df)]:
        for complexity in ["hc", "lc"]:
            stats_list.extend(calculate_stats(df, variant_type, complexity))

    return stats_list


def format_number_decimals(num: float) -> str:
    """Formats a number to three decimal places.

    Args:
        num (float): Number to format

    Returns:
        str: Formatted string with three decimal places
    """
    return f"{num:.3f}"


def summarize_variant_stats(stats_list: List[VariantStats]) -> None:
    """Summarizes and prints variant statistics comparing singleplex and multiplex results.

    Args:
        stats_list (List[VariantStats]): List of VariantStats objects containing
            calculated statistics

    Prints:
        Formatted summary of statistics including precision, sensitivity, and F-measure
        for both singleplex and multiplex variants, along with percentage increases
    """
    for variant_type in ["SNV", "Indel"]:
        for complexity in ["hc", "lc"]:
            print(f"\n{variant_type} Statistics ({complexity.upper()}):")
            print("=" * 40)

            variant_stats = [
                stat
                for stat in stats_list
                if stat.variant_type == variant_type and stat.complexity == complexity
            ]

            singleplex_stats = next(
                stat for stat in variant_stats if stat.multiplexing == "singleplex"
            )
            multiplex_stats = next(
                stat for stat in variant_stats if stat.multiplexing == "multiplex"
            )

            for metric_name, metric_pair in [
                ("Precision", (singleplex_stats.precision, multiplex_stats.precision)),
                (
                    "Sensitivity",
                    (singleplex_stats.sensitivity, multiplex_stats.sensitivity),
                ),
                ("F-measure", (singleplex_stats.f_measure, multiplex_stats.f_measure)),
            ]:
                singleplex_metric, multiplex_metric = metric_pair
                print(f"\n{metric_name}:")

                for stat_name in ["mean", "std", "median"]:
                    singleplex_val = getattr(singleplex_metric, stat_name)
                    multiplex_val = getattr(multiplex_metric, stat_name)
                    print(
                        f"  {stat_name.capitalize():6s}: "
                        f"Singleplex: {format_number_decimals(singleplex_val)}, "
                        f"Multiplex: {format_number_decimals(multiplex_val)}"
                    )

                increase = _calculate_percentage_increase(
                    singleplex_metric.mean, multiplex_metric.mean
                )
                print(
                    f"  Mean Percentage Increase (Singleplex vs Multiplex): "
                    f"{increase:6.2f}%"
                )


variant_multiplexing_stats = create_variant_multiplexing_stats(
    snv_rtg_metrics_dfs["ont"], indel_rtg_metrics_dfs["ont"], nanoplot_qc_metrics_df
)

summarize_variant_stats(variant_multiplexing_stats)
SNV Statistics (HC):
========================================

Precision:
  Mean  : Singleplex: 0.960, Multiplex: 0.944
  Std   : Singleplex: 0.003, Multiplex: 0.008
  Median: Singleplex: 0.961, Multiplex: 0.947
  Mean Percentage Increase (Singleplex vs Multiplex):   1.68%

Sensitivity:
  Mean  : Singleplex: 0.970, Multiplex: 0.936
  Std   : Singleplex: 0.004, Multiplex: 0.018
  Median: Singleplex: 0.970, Multiplex: 0.941
  Mean Percentage Increase (Singleplex vs Multiplex):   3.67%

F-measure:
  Mean  : Singleplex: 0.965, Multiplex: 0.940
  Std   : Singleplex: 0.003, Multiplex: 0.013
  Median: Singleplex: 0.965, Multiplex: 0.945
  Mean Percentage Increase (Singleplex vs Multiplex):   2.67%

SNV Statistics (LC):
========================================

Precision:
  Mean  : Singleplex: 0.788, Multiplex: 0.765
  Std   : Singleplex: 0.005, Multiplex: 0.009
  Median: Singleplex: 0.789, Multiplex: 0.767
  Mean Percentage Increase (Singleplex vs Multiplex):   2.92%

Sensitivity:
  Mean  : Singleplex: 0.747, Multiplex: 0.717
  Std   : Singleplex: 0.005, Multiplex: 0.015
  Median: Singleplex: 0.747, Multiplex: 0.722
  Mean Percentage Increase (Singleplex vs Multiplex):   4.27%

F-measure:
  Mean  : Singleplex: 0.767, Multiplex: 0.740
  Std   : Singleplex: 0.005, Multiplex: 0.012
  Median: Singleplex: 0.767, Multiplex: 0.744
  Mean Percentage Increase (Singleplex vs Multiplex):   3.62%

Indel Statistics (HC):
========================================

Precision:
  Mean  : Singleplex: 0.836, Multiplex: 0.720
  Std   : Singleplex: 0.028, Multiplex: 0.035
  Median: Singleplex: 0.829, Multiplex: 0.734
  Mean Percentage Increase (Singleplex vs Multiplex):  16.13%

Sensitivity:
  Mean  : Singleplex: 0.930, Multiplex: 0.873
  Std   : Singleplex: 0.009, Multiplex: 0.030
  Median: Singleplex: 0.928, Multiplex: 0.882
  Mean Percentage Increase (Singleplex vs Multiplex):   6.52%

F-measure:
  Mean  : Singleplex: 0.880, Multiplex: 0.789
  Std   : Singleplex: 0.019, Multiplex: 0.033
  Median: Singleplex: 0.876, Multiplex: 0.801
  Mean Percentage Increase (Singleplex vs Multiplex):  11.57%

Indel Statistics (LC):
========================================

Precision:
  Mean  : Singleplex: 0.458, Multiplex: 0.310
  Std   : Singleplex: 0.074, Multiplex: 0.019
  Median: Singleplex: 0.425, Multiplex: 0.313
  Mean Percentage Increase (Singleplex vs Multiplex):  47.53%

Sensitivity:
  Mean  : Singleplex: 0.490, Multiplex: 0.424
  Std   : Singleplex: 0.019, Multiplex: 0.024
  Median: Singleplex: 0.483, Multiplex: 0.429
  Mean Percentage Increase (Singleplex vs Multiplex):  15.66%

F-measure:
  Mean  : Singleplex: 0.472, Multiplex: 0.358
  Std   : Singleplex: 0.048, Multiplex: 0.021
  Median: Singleplex: 0.452, Multiplex: 0.362
  Mean Percentage Increase (Singleplex vs Multiplex):  31.70%

2. Impact of sequencing depth on variant calling¶

In [37]:
@dataclass
class PerformanceCorrelation:
    """Data class for storing performance correlation results."""

    correlation: float
    p_value: float
    fit_params: Tuple[float, float, float]
    confidence_intervals: np.ndarray


def asymptotic_func(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray:
    """Calculate asymptotic function.

    Args:
        x: Input array
        a: First parameter
        b: Second parameter
        c: Third parameter

    Returns:
        Calculated asymptotic values
    """
    return a - b * np.exp(-c * x)


def calculate_correlation_stats(x: np.ndarray, y: np.ndarray) -> PerformanceCorrelation:
    """Calculate correlation statistics between two variables.

    Args:
        x: Independent variable array
        y: Dependent variable array

    Returns:
        PerformanceCorrelation object containing correlation statistics

    Raises:
        ValueError: If curve fitting fails
    """
    try:
        # Calculate Pearson correlation
        correlation, p_value = stats.pearsonr(x, y)

        # Fit asymptotic curve
        popt, pcov = curve_fit(
            asymptotic_func, x, y, p0=[1, 0.1, 0.1], bounds=([0, 0, 0], [2, 1, 1])
        )

        # Calculate confidence intervals
        perr = np.sqrt(np.diag(pcov))
        n = len(x)
        dof = max(0, n - len(popt))
        t = stats.t.ppf(0.975, dof)
        y_err = np.sqrt(np.sum((y - asymptotic_func(x, *popt)) ** 2) / dof)

        x_range = np.linspace(x.min(), x.max(), 100)
        ci = (
            t
            * y_err
            * np.sqrt(
                1 / n + (x_range - np.mean(x)) ** 2 / np.sum((x - np.mean(x)) ** 2)
            )
        )

        return PerformanceCorrelation(
            correlation=correlation,
            p_value=p_value,
            fit_params=tuple(popt),
            confidence_intervals=ci,
        )

    except Exception as e:
        raise ValueError(f"Error calculating correlation statistics: {str(e)}")


def plot_depth_vs_performance(
    depth_df: pl.DataFrame,
    snv_metrics_df: pl.DataFrame,
    indel_metrics_df: pl.DataFrame,
    metrics_df: pl.DataFrame,
    figsize: Tuple[int, int] = (14, 16),
    dpi: int = 300,
) -> plt.Figure:
    """Create plots comparing variant calling performance metrics against sequencing depth.

    Args:
        depth_df: DataFrame containing whole genome depth statistics
        snv_metrics_df: DataFrame containing SNV calling metrics
        indel_metrics_df: DataFrame containing indel calling metrics
        metrics_df: DataFrame containing sample metadata
        figsize: Figure size in inches
        dpi: Figure resolution

    Returns:
        plt.Figure: Figure object containing the plots

    Raises:
        ValueError: If required columns are missing
    """
    try:
        metrics = ["precision", "sensitivity", "f_measure"]
        variant_types = ["SNV", "Indel"]
        complexities = ["hc", "lc"]

        fig = plt.figure(figsize=figsize, dpi=dpi)
        # Add height_ratios for padding between sections
        gs = gridspec.GridSpec(5, 1, height_ratios=[1, 1, 0.05, 1, 1])
        row_positions = {}

        # Add section titles
        fig.text(
            0.5,
            0.99,
            "SNV Performance Metrics vs Whole Genome Mean Depth",
            ha="center",
            fontsize=12,
            fontweight="bold",
        )
        fig.text(
            0.5,
            0.468,
            "Indel Performance Metrics vs Whole Genome Mean Depth",
            ha="center",
            fontsize=12,
            fontweight="bold",
        )

        for i, complexity in enumerate(complexities):
            y_limits = {metric: (float("inf"), -float("inf")) for metric in metrics}

            for j, (variant_type, variant_metrics_df) in enumerate(
                zip(variant_types, [snv_metrics_df, indel_metrics_df])
            ):
                # Adjust row index to account for padding
                main_row_index = j * 3 + i if j == 1 else i
                inner_gs = gridspec.GridSpecFromSubplotSpec(
                    1, 3, subplot_spec=gs[main_row_index]
                )

                data = (
                    variant_metrics_df.filter(pl.col("complexity") == complexity)
                    .join(
                        metrics_df.select(["sample", "multiplexing"]),
                        left_on="sample_id",
                        right_on="sample",
                    )
                    .join(
                        depth_df.select(["sample", "mean_depth"]),
                        left_on="sample_id",
                        right_on="sample",
                    )
                )

                for k, metric in enumerate(metrics):
                    ax = plt.subplot(inner_gs[k])

                    if k == 0:
                        row_positions[main_row_index] = ax.get_position()
                        ax.annotate(
                            chr(ord("A") + (j * 2 + i)),
                            xy=(-0.1, 1.05),
                            xycoords="axes fraction",
                            fontsize=12,
                            fontweight="bold",
                        )
                        # Add complexity label
                        complexity_label = (
                            "High Complexity"
                            if complexity == "hc"
                            else "Low Complexity"
                        )
                        ax.annotate(
                            complexity_label,
                            xy=(-0.18, 0.5),
                            xycoords="axes fraction",
                            fontweight="bold",
                            ha="right",
                            va="center",
                            rotation=90,
                        )

                    multiplexing_values = sorted(
                        data["multiplexing"].unique().to_list()
                    )
                    colors = sns.color_palette("colorblind", len(multiplexing_values))
                    color_mapping = dict(zip(multiplexing_values, colors))

                    for multiplex in multiplexing_values:
                        subset = data.filter(pl.col("multiplexing") == multiplex)
                        ax.scatter(
                            subset["mean_depth"].to_numpy(),
                            subset[metric].to_numpy(),
                            color=color_mapping[multiplex],
                            label=(
                                str(multiplex) if i == 0 and j == 0 and k == 0 else ""
                            ),
                            s=100,
                        )

                    x = data["mean_depth"].to_numpy()
                    y = data[metric].to_numpy()

                    y_limits[metric] = (
                        min(y_limits[metric][0], y.min()),
                        max(y_limits[metric][1], y.max()),
                    )

                    corr_stats = calculate_correlation_stats(x, y)

                    x_range = np.linspace(x.min(), x.max(), 100)
                    y_fit = asymptotic_func(x_range, *corr_stats.fit_params)

                    ax.plot(
                        x_range,
                        y_fit,
                        color="gray",
                        linestyle="-",
                        linewidth=2,
                        label=(
                            "Line of Best Fit" if i == 0 and j == 0 and k == 0 else ""
                        ),
                    )
                    ax.fill_between(
                        x_range,
                        y_fit - corr_stats.confidence_intervals,
                        y_fit + corr_stats.confidence_intervals,
                        color="gray",
                        alpha=0.2,
                        label=(
                            "95% Confidence Interval"
                            if i == 0 and j == 0 and k == 0
                            else ""
                        ),
                    )

                    ax.set_title(
                        f"r={corr_stats.correlation:.2f}, p={corr_stats.p_value:.2e}"
                    )
                    ax.set_xlabel("Whole Genome Mean Depth")
                    ax.set_ylabel(metric.capitalize().replace("_", "-"))

                    if i == 0 and j == 0 and k == 0:
                        handles, labels = ax.get_legend_handles_labels()
                        legend = ax.legend(
                            handles=handles, loc="lower right", title="Multiplexing"
                        )
                        legend.get_title().set_weight("bold")

        plt.tight_layout()
        return fig

    except Exception as e:
        logger.error(f"Error creating depth vs performance plots: {str(e)}")
        raise


performance_depth_fig = plot_depth_vs_performance(
    depth_df=total_depth_df,
    snv_metrics_df=snv_rtg_metrics_dfs["ont"],
    indel_metrics_df=indel_rtg_metrics_dfs["ont"],
    metrics_df=nanoplot_qc_metrics_df,
)
No description has been provided for this image
In [38]:
@dataclass
class AncovaResult:
    """Data class for storing ANCOVA analysis results."""

    depth_effect: float
    depth_ci_low: float
    depth_ci_high: float
    depth_pvalue: float
    multiplex_effect: float
    multiplex_ci_low: float
    multiplex_ci_high: float
    multiplex_pvalue: float
    r_squared: float
    adj_r_squared: float


def _prepare_ancova_data(
    depth_df: pl.DataFrame, metrics_df: pl.DataFrame, np_metrics_df: pl.DataFrame
) -> pl.DataFrame:
    """Prepare data for ANCOVA analysis by merging relevant dataframes.

    Args:
        depth_df: DataFrame containing depth information
        metrics_df: DataFrame containing metrics information
        np_metrics_df: DataFrame containing nanoplot metrics information

    Returns:
        pl.DataFrame: Combined and processed DataFrame for ANCOVA analysis
    """
    try:
        depth_data = depth_df.group_by("sample").agg(
            pl.col("mean_depth").mean().alias("wg_mean_depth")
        )

        # First join metrics with nanoplot data
        data = metrics_df.join(
            np_metrics_df.select(["sample", "multiplexing"]),
            left_on="sample_id",
            right_on="sample",
        )

        # Then join with depth data using sample_id
        data = data.join(
            depth_data,
            left_on="sample_id",
            right_on="sample",
        )

        data = data.with_columns(
            pl.when(pl.col("multiplexing") == "multiplex")
            .then(1)
            .otherwise(0)
            .alias("multiplexing_dummy")
        )

        return data
    except Exception as e:
        logger.error(f"Error preparing ANCOVA data: {str(e)}")
        raise


def _get_ancova_statistical_significance(
    pvalue: float, thresholds: Dict[str, float] = {"***": 0.001, "**": 0.01, "*": 0.05}
) -> str:
    """Determine statistical significance notation based on p-value.

    Args:
        pvalue: P-value from statistical test
        thresholds: Dictionary of significance thresholds and their corresponding symbols,
                   default is standard thresholds (***: p<0.001, **: p<0.01, *: p<0.05)

    Returns:
        str: Significance stars ("***", "**", "*") or empty string if not significant
    """
    try:
        if not isinstance(pvalue, (int, float)):
            raise ValueError(f"P-value must be numeric, got {type(pvalue)}")

        if pvalue < 0 or pvalue > 1:
            raise ValueError(f"P-value must be between 0 and 1, got {pvalue}")

        # Sort thresholds by value in descending order to check most stringent first
        sorted_thresholds = dict(sorted(thresholds.items(), key=lambda x: x[1]))

        for symbol, threshold in sorted_thresholds.items():
            if pvalue < threshold:
                return symbol

        return ""

    except Exception as e:
        logger.error(f"Error determining statistical significance: {str(e)}")
        raise


def create_forest_plot_ancova(
    snv_results: Dict[str, AncovaResult],
    indel_results: Dict[str, AncovaResult],
    figsize: Tuple[int, int] = (12, 7),
    dpi: int = 300,
) -> plt.Figure:
    """Create forest plot for ANCOVA analysis results with colored effects and significance stars.

    Args:
        snv_results: Dictionary of SNV ANCOVA results
        indel_results: Dictionary of INDEL ANCOVA results
        figsize: Figure size (width, height)
        dpi: Figure resolution

    Returns:
        plt.Figure: Generated matplotlib figure
    """
    try:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize, dpi=dpi)
        palette = sns.color_palette("colorblind")

        def _add_results_to_plot(
            ax: plt.Axes,
            results: Dict[str, AncovaResult],
            title: str,
            show_legend: bool = True,
        ) -> None:
            # Add panel label
            ax.text(
                -0.1,
                1.1,
                "A" if title == "SNVs" else "B",
                transform=ax.transAxes,
                fontsize=14,
                fontweight="bold",
            )

            metrics = ["Precision", "Sensitivity", "F-measure"]
            complexities = ["hc", "lc"]

            # Prepare data structure
            plot_data = []
            labels = []
            for metric in metrics:
                for complexity in complexities:
                    key = f"{metric}_{complexity}"
                    plot_data.append((metric, complexity, results[key]))
                    labels.append(f"{metric} ({complexity.upper()})")

            y_positions = np.arange(len(plot_data))

            # Plot multiplexing and depth effects
            for i, (_, _, result) in enumerate(plot_data):
                # Multiplexing effect
                ax.plot(
                    [result.multiplex_effect],
                    [y_positions[i]],
                    "o",
                    color=palette[0],
                    label="Multiplexing" if i == 0 else "",
                )
                ax.plot(
                    [result.multiplex_ci_low, result.multiplex_ci_high],
                    [y_positions[i], y_positions[i]],
                    "-",
                    color=palette[0],
                )

                stars = _get_ancova_statistical_significance(result.multiplex_pvalue)
                if stars:
                    ax.text(
                        result.multiplex_effect,
                        y_positions[i] + 0.1,
                        stars,
                        ha="left",
                        va="center",
                        color=palette[0],
                    )

                # Depth effect
                ax.plot(
                    [result.depth_effect],
                    [y_positions[i]],
                    "o",
                    color=palette[1],
                    label="Depth" if i == 0 else "",
                )
                ax.plot(
                    [result.depth_ci_low, result.depth_ci_high],
                    [y_positions[i], y_positions[i]],
                    "-",
                    color=palette[1],
                )

                stars = _get_ancova_statistical_significance(result.depth_pvalue)
                if stars:
                    ax.text(
                        result.depth_effect,
                        y_positions[i] + 0.1,
                        stars,
                        ha="left",
                        va="center",
                        color=palette[1],
                    )

            # Customize plot
            ax.axvline(x=0, color="gray", linestyle="--", alpha=0.5)
            ax.set_yticks(y_positions)
            ax.set_yticklabels(labels)
            ax.set_title(f"{title}", fontweight="bold")
            ax.set_xlabel("Effect Size (with 95% Confidence Intervals)")

            if show_legend:
                legend = ax.legend(loc="lower right", title="Metrics")
                legend.get_title().set_fontweight("bold")
            elif ax.get_legend():
                ax.get_legend().remove()

        # Create plots for SNVs and INDELs
        _add_results_to_plot(ax1, snv_results, "SNVs", show_legend=True)
        _add_results_to_plot(ax2, indel_results, "Indels", show_legend=False)

        plt.tight_layout()
        return fig

    except Exception as e:
        logger.error(f"Error creating forest plot: {str(e)}")
        raise


def _perform_ancova(data: pl.DataFrame, metric: str, complexity: str) -> AncovaResult:
    """Perform Analysis of Covariance (ANCOVA) for a specific metric and complexity level.

    Args:
        data: Polars DataFrame containing the analysis data
        metric: The metric to analyze ('Precision', 'Sensitivity', or 'F-measure')
        complexity: The complexity level to analyze ('hc' or 'lc')

    Returns:
        AncovaResult: Dataclass containing the ANCOVA analysis results including:
            - Effects, confidence intervals, and p-values for depth and multiplexing
            - R-squared and adjusted R-squared values

    Raises:
        Exception: If there's an error during the ANCOVA analysis
    """
    try:
        # Create a mapping for column names
        metric_mapping = {
            "Precision": "precision",
            "Sensitivity": "sensitivity",
            "F-measure": "f_measure",
        }

        subset = data.filter(pl.col("complexity") == complexity)
        subset = subset.with_columns(pl.col("wg_mean_depth").log().alias("log_depth"))

        X = sm.add_constant(
            subset.select(["log_depth", "multiplexing_dummy"]).to_numpy()
        )
        # Use the mapped column name
        y = subset.select(metric_mapping[metric]).to_numpy().flatten()

        model = sm.OLS(y, X).fit()
        conf_int = model.conf_int()

        return AncovaResult(
            depth_effect=model.params[1],
            depth_ci_low=conf_int[1, 0],
            depth_ci_high=conf_int[1, 1],
            depth_pvalue=model.pvalues[1],
            multiplex_effect=model.params[2],
            multiplex_ci_low=conf_int[2, 0],
            multiplex_ci_high=conf_int[2, 1],
            multiplex_pvalue=model.pvalues[2],
            r_squared=model.rsquared,
            adj_r_squared=model.rsquared_adj,
        )
    except Exception as e:
        logger.error(f"Error performing ANCOVA analysis: {str(e)}")
        raise


def print_ancova_results(
    snv_results: Dict[str, AncovaResult], indel_results: Dict[str, AncovaResult]
) -> None:
    """Print formatted ANCOVA analysis results for SNVs and INDELs.

    Args:
        snv_results: Dictionary mapping metric_complexity to AncovaResult for SNVs
        indel_results: Dictionary mapping metric_complexity to AncovaResult for INDELs

    Output format:
        - Results are grouped by variant type (SNV/INDEL)
        - Each row shows metric, complexity, depth effect, and multiplexing effect
        - Effects are displayed with confidence intervals and significance stars
        - Significance levels: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant
    """

    def _format_effect(
        effect: float, ci_low: float, ci_high: float, pvalue: float
    ) -> str:
        stars = _get_ancova_statistical_significance(pvalue)
        significance = f" {stars}" if stars else " (ns)"
        return (
            f"{effect:.3f} [{ci_low:.3f}, {ci_high:.3f}]{significance} "
            f"(p={pvalue:.4f})"
        )

    def _print_variant_results(
        results: Dict[str, AncovaResult], variant_type: str
    ) -> None:
        print(f"\n{variant_type} Results:")
        print("=" * 80)
        print(
            f"{'Metric':<15} {'Complexity':<10} {'Depth Effect':<35} {'Multiplexing Effect':<35}"
        )
        print("-" * 80)

        metrics = ["Precision", "Sensitivity", "F-measure"]
        complexities = ["hc", "lc"]

        for metric in metrics:
            for complexity in complexities:
                key = f"{metric}_{complexity}"
                result = results[key]

                depth_effect = _format_effect(
                    result.depth_effect,
                    result.depth_ci_low,
                    result.depth_ci_high,
                    result.depth_pvalue,
                )

                multiplex_effect = _format_effect(
                    result.multiplex_effect,
                    result.multiplex_ci_low,
                    result.multiplex_ci_high,
                    result.multiplex_pvalue,
                )

                print(
                    f"{metric:<15} {complexity.upper():<10} {depth_effect:<35} "
                    f"{multiplex_effect:<35}"
                )
            print("-" * 80)

    print("\nANCOVA Analysis Results")
    print("=====================")
    print("Significance levels: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant")

    _print_variant_results(snv_results, "SNV")
    _print_variant_results(indel_results, "INDEL")


def run_ancova_analysis(
    depth_df: pl.DataFrame,
    snv_ont_metrics_df: pl.DataFrame,
    indel_ont_metrics_df: pl.DataFrame,
    np_metrics_df: pl.DataFrame,
) -> Tuple[Dict[str, AncovaResult], Dict[str, AncovaResult]]:
    """Execute ANCOVA analysis for SNV and Indel metrics.

    Args:
        depth_df: Polars DataFrame containing depth information
        snv_ont_metrics_df: Polars DataFrame containing SNV metrics
        indel_ont_metrics_df: Polars DataFrame containing INDEL metrics
        np_metrics_df: Polars DataFrame containing nanoplot metrics

    Returns:
        Tuple[Dict[str, AncovaResult], Dict[str, AncovaResult]]: Two dictionaries containing
            ANCOVA results for SNVs and INDELs respectively, with keys formatted as
            'metric_complexity' (e.g., 'Precision_hc')

    Raises:
        Exception: If there's an error during any stage of the analysis
    """
    try:
        logger.info("Starting ANCOVA analysis")

        snv_data = _prepare_ancova_data(depth_df, snv_ont_metrics_df, np_metrics_df)
        indel_data = _prepare_ancova_data(depth_df, indel_ont_metrics_df, np_metrics_df)

        metrics = ["Precision", "Sensitivity", "F-measure"]
        complexities = ["hc", "lc"]

        snv_results: Dict[str, AncovaResult] = {}
        indel_results: Dict[str, AncovaResult] = {}

        for metric in metrics:
            for complexity in complexities:
                key = f"{metric}_{complexity}"
                logger.info(f"Analyzing {key}")

                snv_results[key] = _perform_ancova(snv_data, metric, complexity)
                indel_results[key] = _perform_ancova(indel_data, metric, complexity)

        logger.info("ANCOVA analysis completed successfully")
        return snv_results, indel_results

    except Exception as e:
        logger.error(f"Error in ANCOVA analysis: {str(e)}")
        raise


snv_ancova_results, indel_ancova_results = run_ancova_analysis(
    depth_df=total_depth_df,
    snv_ont_metrics_df=snv_rtg_metrics_dfs["ont"],
    indel_ont_metrics_df=indel_rtg_metrics_dfs["ont"],
    np_metrics_df=nanoplot_qc_metrics_df,
)

print_ancova_results(snv_ancova_results, indel_ancova_results)

forest_plot_ancova = create_forest_plot_ancova(
    snv_results=snv_ancova_results,
    indel_results=indel_ancova_results,
)
__main__ - INFO - Starting ANCOVA analysis
__main__ - INFO - Analyzing Precision_hc
__main__ - INFO - Analyzing Precision_lc
__main__ - INFO - Analyzing Sensitivity_hc
__main__ - INFO - Analyzing Sensitivity_lc
__main__ - INFO - Analyzing F-measure_hc
__main__ - INFO - Analyzing F-measure_lc
__main__ - INFO - ANCOVA analysis completed successfully
ANCOVA Analysis Results
=====================
Significance levels: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant

SNV Results:
================================================================================
Metric          Complexity Depth Effect                        Multiplexing Effect                
--------------------------------------------------------------------------------
Precision       HC         0.015 [0.001, 0.029] * (p=0.0341)   -0.005 [-0.017, 0.006] (ns) (p=0.3128)
Precision       LC         0.022 [0.007, 0.038] ** (p=0.0088)  -0.007 [-0.019, 0.005] (ns) (p=0.2279)
--------------------------------------------------------------------------------
Sensitivity     HC         0.037 [0.009, 0.065] * (p=0.0133)   -0.009 [-0.031, 0.013] (ns) (p=0.3915)
Sensitivity     LC         0.034 [0.011, 0.056] ** (p=0.0072)  -0.008 [-0.026, 0.010] (ns) (p=0.3735)
--------------------------------------------------------------------------------
F-measure       HC         0.026 [0.006, 0.046] * (p=0.0143)   -0.007 [-0.023, 0.009] (ns) (p=0.3412)
F-measure       LC         0.028 [0.011, 0.046] ** (p=0.0038)  -0.007 [-0.021, 0.006] (ns) (p=0.2608)
--------------------------------------------------------------------------------

INDEL Results:
================================================================================
Metric          Complexity Depth Effect                        Multiplexing Effect                
--------------------------------------------------------------------------------
Precision       HC         0.131 [0.094, 0.168] *** (p=0.0000) -0.027 [-0.056, 0.003] (ns) (p=0.0717)
Precision       LC         0.252 [0.201, 0.304] *** (p=0.0000) 0.025 [-0.016, 0.066] (ns) (p=0.2099)
--------------------------------------------------------------------------------
Sensitivity     HC         0.065 [0.023, 0.108] ** (p=0.0063)  -0.012 [-0.046, 0.022] (ns) (p=0.4423)
Sensitivity     LC         0.087 [0.063, 0.111] *** (p=0.0000) -0.007 [-0.026, 0.013] (ns) (p=0.4584)
--------------------------------------------------------------------------------
F-measure       HC         0.102 [0.063, 0.141] *** (p=0.0001) -0.021 [-0.052, 0.010] (ns) (p=0.1580)
F-measure       LC         0.174 [0.153, 0.194] *** (p=0.0000) 0.005 [-0.011, 0.022] (ns) (p=0.5005)
--------------------------------------------------------------------------------
No description has been provided for this image

3. Impact of read length on variant calling¶

In [39]:
@dataclass
class PerformanceCorrelation:
    """Data class for storing performance correlation results."""

    correlation: float
    p_value: float
    fit_params: Tuple[float, float]
    confidence_intervals: np.ndarray


def plot_read_length_vs_performance(
    snv_metrics_df: pl.DataFrame,
    indel_metrics_df: pl.DataFrame,
    metrics_df: pl.DataFrame,
    figsize: Tuple[int, int] = (14, 16),
    dpi: int = 300,
) -> plt.Figure:
    """Create plots comparing variant calling performance metrics against read length.

    Args:
        snv_metrics_df: DataFrame containing SNV calling metrics
        indel_metrics_df: DataFrame containing indel calling metrics
        metrics_df: DataFrame containing sample metadata with mean read lengths
        figsize: Figure size in inches
        dpi: Figure resolution

    Returns:
        plt.Figure: Figure object containing the plots
    """
    try:
        metrics = ["precision", "sensitivity", "f_measure"]
        variant_types = ["SNV", "Indel"]
        complexities = ["hc", "lc"]

        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = gridspec.GridSpec(5, 1, height_ratios=[1, 1, 0.05, 1, 1])

        # Add section titles
        fig.text(
            0.5,
            0.99,
            "SNV Performance Metrics vs Mean Read Length",
            ha="center",
            fontsize=12,
            fontweight="bold",
        )
        fig.text(
            0.5,
            0.468,
            "Indel Performance Metrics vs Mean Read Length",
            ha="center",
            fontsize=12,
            fontweight="bold",
        )

        for i, complexity in enumerate(complexities):
            for j, (variant_type, variant_metrics_df) in enumerate(
                zip(variant_types, [snv_metrics_df, indel_metrics_df])
            ):
                main_row_index = j * 3 + i if j == 1 else i
                inner_gs = gridspec.GridSpecFromSubplotSpec(
                    1, 3, subplot_spec=gs[main_row_index]
                )

                # Filter and join data
                data = variant_metrics_df.filter(
                    pl.col("complexity") == complexity
                ).join(
                    metrics_df.select(["sample", "multiplexing", "mean_read_length"]),
                    left_on="sample_id",
                    right_on="sample",
                )

                if len(data) == 0:
                    logger.warning(
                        f"No data after joining for {variant_type} ({complexity})"
                    )
                    continue

                for k, metric in enumerate(metrics):
                    ax = plt.subplot(inner_gs[k])

                    if k == 0:
                        ax.annotate(
                            chr(ord("A") + (j * 2 + i)),
                            xy=(-0.1, 1.05),
                            xycoords="axes fraction",
                            fontsize=12,
                            fontweight="bold",
                        )
                        complexity_label = (
                            "High Complexity"
                            if complexity == "hc"
                            else "Low Complexity"
                        )
                        ax.annotate(
                            complexity_label,
                            xy=(-0.18, 0.5),
                            xycoords="axes fraction",
                            fontweight="bold",
                            ha="right",
                            va="center",
                            rotation=90,
                        )

                    multiplexing_values = sorted(
                        data["multiplexing"].unique().to_list()
                    )
                    colors = sns.color_palette("colorblind", len(multiplexing_values))
                    color_mapping = dict(zip(multiplexing_values, colors))

                    for multiplex in multiplexing_values:
                        subset = data.filter(pl.col("multiplexing") == multiplex)
                        x = subset["mean_read_length"].to_numpy()
                        y = subset[metric].to_numpy()

                        ax.scatter(
                            x,
                            y,
                            color=color_mapping[multiplex],
                            label=(
                                str(multiplex) if i == 0 and j == 0 and k == 0 else ""
                            ),
                            s=100,
                        )

                    x = data["mean_read_length"].to_numpy()
                    y = data[metric].to_numpy()

                    slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
                    x_range = np.linspace(x.min(), x.max(), 100)
                    y_fit = slope * x_range + intercept

                    ax.plot(
                        x_range,
                        y_fit,
                        color="gray",
                        linestyle="-",
                        linewidth=2,
                        label=(
                            "Line of Best Fit" if i == 0 and j == 0 and k == 0 else ""
                        ),
                    )

                    n = len(x)
                    y_err = np.sqrt(
                        np.sum((y - (slope * x + intercept)) ** 2) / (n - 2)
                    )
                    ci = (
                        stats.t.ppf(0.975, n - 2)
                        * y_err
                        * np.sqrt(
                            1 / n
                            + (x_range - np.mean(x)) ** 2
                            / np.sum((x - np.mean(x)) ** 2)
                        )
                    )

                    ax.fill_between(
                        x_range,
                        y_fit - ci,
                        y_fit + ci,
                        color="gray",
                        alpha=0.2,
                        label=(
                            "95% Confidence Interval"
                            if i == 0 and j == 0 and k == 0
                            else ""
                        ),
                    )

                    ax.set_title(f"r={r_value:.2f}, p={p_value:.2e}")
                    ax.set_xlabel("Mean Read Length")
                    ax.set_ylabel(metric.capitalize())

                    if i == 0 and j == 0 and k == 0:
                        handles, labels = ax.get_legend_handles_labels()
                        legend = ax.legend(
                            handles=handles,
                            loc="lower left",
                            title="Multiplexing",
                        )
                        legend.get_title().set_weight("bold")

        plt.tight_layout()
        return fig

    except Exception as e:
        logger.error(f"Error creating read length vs performance plots: {str(e)}")
        raise


performance_readlength_fig = plot_read_length_vs_performance(
    snv_metrics_df=snv_rtg_metrics_dfs["ont"],
    indel_metrics_df=indel_rtg_metrics_dfs["ont"],
    metrics_df=nanoplot_qc_metrics_df,
)
No description has been provided for this image

SV Benchmark¶

1. SV Consensus Calls¶

In [40]:
@dataclass
class SVMetrics:
    """Data class for structural variant metrics."""

    type: str
    length: Optional[int]
    chrom: str
    start: int
    end: int
    allele_idx: int


@dataclass
class SVAnalysisConfig:
    """Configuration for SV analysis."""

    base_path: Path
    technologies: Tuple[str, ...] = ("ont", "illumina")


def parse_int_or_first(value: Any) -> int:
    """
    Parse integer value from various input types.

    Args:
        value: Input value that could be int, float, str, or tuple

    Returns:
        Parsed integer value

    Raises:
        ValueError: If value cannot be parsed as integer
    """
    if isinstance(value, (int, float)):
        return int(value)
    elif isinstance(value, str):
        return int(value.split("/")[0])
    elif isinstance(value, tuple):
        return int(value[0])
    else:
        raise ValueError(f"Unexpected value type: {type(value)}")


def handle_str(
    record: Any, alt: str, chrom: str, start: int, end: int, alt_idx: int
) -> Optional[Dict[str, Any]]:
    """
    Handle Short Tandem Repeat (STR) variants.

    Args:
        record: VCF record
        alt: Alternative allele
        chrom: Chromosome
        start: Start position
        end: End position
        alt_idx: Alternative allele index

    Returns:
        Dictionary containing STR information or None if invalid
    """
    try:
        repcn = record.samples[0].get("REPCN")

        if repcn is not None:
            if isinstance(repcn, tuple):
                repeat_count = parse_int_or_first(repcn[alt_idx])
            else:
                repeat_count = parse_int_or_first(repcn)
        elif alt.startswith("<STR"):
            str_alleles = record.alts
            current_alt = str_alleles[alt_idx]
            repeat_count = int(current_alt[4:-1])
        else:
            return None

        ru = record.info.get("RU", "")
        sv_len = repeat_count * len(ru)

        return {
            "type": "STR",
            "length": sv_len,
            "chrom": chrom,
            "start": start,
            "end": end,
        }
    except Exception as e:
        logger.error(f"Error handling STR variant: {str(e)}")
        return None


def handle_symbolic_allele(
    record: Any, alt: str, chrom: str, start: int, end: int, sv_type: str, alt_idx: int
) -> Dict[str, Any]:
    """
    Handle symbolic allele variants.

    Args:
        record: VCF record
        alt: Alternative allele
        chrom: Chromosome
        start: Start position
        end: End position
        sv_type: Structural variant type
        alt_idx: Alternative allele index

    Returns:
        Dictionary containing symbolic allele information
    """
    sv_len = None

    if sv_type == "INV" and "SVINSLEN" in record.info:
        sv_len = record.info.get("SVINSLEN")
        if isinstance(sv_len, tuple):
            sv_len = sv_len[alt_idx] if len(sv_len) > alt_idx else sv_len[0]

    if sv_len is None:
        sv_len = record.info.get("SVLEN")
        if isinstance(sv_len, tuple):
            sv_len = sv_len[alt_idx] if len(sv_len) > alt_idx else sv_len[0]

    if sv_len is None and sv_type == "INS":
        left_seq = record.info.get("LEFT_SVINSSEQ", "")
        right_seq = record.info.get("RIGHT_SVINSSEQ", "")
        sv_len = len(left_seq) + len(right_seq)

    if sv_len is None:
        sv_len = end - start

    return {
        "type": sv_type,
        "length": abs(sv_len) if sv_len is not None else None,
        "chrom": chrom,
        "start": start,
        "end": end,
    }


def handle_standard_sv(
    record: Any, alt: str, chrom: str, start: int, end: int, sv_type: str, alt_idx: int
) -> Dict[str, Any]:
    """
    Handle standard structural variants.

    Args:
        record: VCF record
        alt: Alternative allele
        chrom: Chromosome
        start: Start position
        end: End position
        sv_type: Structural variant type
        alt_idx: Alternative allele index

    Returns:
        Dictionary containing standard SV information
    """
    if "SVLEN" in record.info:
        sv_len = record.info["SVLEN"]
        if isinstance(sv_len, tuple):
            sv_len = sv_len[alt_idx] if len(sv_len) > alt_idx else sv_len[0]
    elif sv_type == "INS":
        left_seq = record.info.get("LEFT_SVINSSEQ", "")
        right_seq = record.info.get("RIGHT_SVINSSEQ", "")
        sv_len = len(left_seq) + len(right_seq)
    else:
        ref_len = len(record.ref)
        alt_len = len(alt)
        sv_len = alt_len - ref_len if sv_type == "INS" else ref_len - alt_len

    return {
        "type": sv_type,
        "length": abs(sv_len),
        "chrom": chrom,
        "start": start,
        "end": end,
    }


def extract_sv_info(record: Any, alt: str, alt_idx: int) -> Optional[Dict[str, Any]]:
    """
    Extract structural variant information from VCF record.

    Args:
        record: VCF record
        alt: Alternative allele
        alt_idx: Alternative allele index

    Returns:
        Dictionary containing SV information or None if invalid
    """
    chrom = record.chrom
    start = record.pos
    end = record.info.get("END", start)
    sv_type = record.info.get("SVTYPE", "Unknown")

    if sv_type == "STR" or (isinstance(alt, str) and alt.startswith("<STR")):
        return handle_str(record, alt, chrom, start, end, alt_idx)
    elif isinstance(alt, str) and alt.startswith("<") and alt.endswith(">"):
        return handle_symbolic_allele(record, alt, chrom, start, end, sv_type, alt_idx)
    else:
        return handle_standard_sv(record, alt, chrom, start, end, sv_type, alt_idx)


def read_sv_vcf_file(file_path: Path) -> pl.DataFrame:
    """
    Read and parse structural variant VCF file.

    Args:
        file_path: Path to the VCF file

    Returns:
        Polars DataFrame containing parsed SV information

    Raises:
        FileNotFoundError: If VCF file does not exist
        ValueError: If VCF file is malformed
    """
    try:
        svs: List[Dict] = []
        with pysam.VariantFile(file_path) as vcf:
            for record in vcf:
                for alt_idx, alt in enumerate(record.alts):
                    sv_info = extract_sv_info(record, alt, alt_idx)
                    if sv_info:
                        sv_info["allele_idx"] = alt_idx
                        svs.append(sv_info)
        return pl.DataFrame(svs)
    except FileNotFoundError:
        logger.error(f"VCF file not found: {file_path}")
        raise
    except Exception as e:
        logger.error(f"Error reading VCF file {file_path}: {str(e)}")
        raise ValueError(f"Error parsing VCF file: {str(e)}")


def analyze_sv_calls(
    sample_id: str,
    ont_id: str,
    illumina_id: str,
    base_path: Path = Path("/scratch/prj/ppn_als_longread/ont-benchmark"),
) -> pl.DataFrame:
    """
    Analyze structural variant calls from ONT and Illumina data.

    Args:
        sample_id: Sample identifier
        ont_id: ONT sample identifier
        illumina_id: Illumina sample identifier
        base_path: Base path for input files

    Returns:
        Polars DataFrame containing combined SV analysis results

    Raises:
        FileNotFoundError: If required input files are missing
    """
    try:
        sv_path = base_path / "output/sv/survivor"
        ont_file = sv_path / ont_id / f"{ont_id}.ont.sv_str.filtered.vcf"
        illumina_file = sv_path / ont_id / f"{illumina_id}.illumina.sv.filtered.vcf"
        merged_file = sv_path / ont_id / f"{ont_id}_{illumina_id}_merged.vcf"

        ont_svs = read_sv_vcf_file(ont_file)
        illumina_svs = read_sv_vcf_file(illumina_file)
        merged_svs = read_sv_vcf_file(merged_file)

        # Create unique identifiers
        ont_svs = ont_svs.with_columns(
            pl.concat_str(
                pl.col("chrom"),
                pl.col("start"),
                pl.col("end"),
                pl.col("type"),
                separator="_",
            ).alias("sv_id")
        )
        illumina_svs = illumina_svs.with_columns(
            pl.concat_str(
                pl.col("chrom"),
                pl.col("start"),
                pl.col("end"),
                pl.col("type"),
                separator="_",
            ).alias("sv_id")
        )
        merged_svs = merged_svs.with_columns(
            pl.concat_str(
                pl.col("chrom"),
                pl.col("start"),
                pl.col("end"),
                pl.col("type"),
                separator="_",
            ).alias("sv_id")
        )

        # Combine and mark sources
        all_svs = pl.concat([ont_svs, illumina_svs]).unique(subset="sv_id")
        all_svs = all_svs.with_columns(
            [
                pl.col("sv_id").is_in(ont_svs["sv_id"]).alias("ONT"),
                pl.col("sv_id").is_in(illumina_svs["sv_id"]).alias("Illumina"),
                pl.col("sv_id").is_in(merged_svs["sv_id"]).alias("Merged"),
                pl.lit(sample_id).alias("sample_id"),
            ]
        )

        return all_svs.drop("sv_id")

    except Exception as e:
        logger.error(f"Error analyzing SV calls for sample {sample_id}: {str(e)}")
        raise


sv_data_list = []

for row in sample_ids.iter_rows(named=True):
    try:
        sample_data = analyze_sv_calls(
            sample_id=row["ont_id"], ont_id=row["ont_id"], illumina_id=row["lp_id"]
        )
        sv_data_list.append(sample_data)
    except Exception as e:
        logger.error(f"Failed to process sample {row['ont_id']}: {str(e)}")
        continue

sv_data_df = pl.concat(sv_data_list)

logger.info(f"Total SV calls processed: {sv_data_df.height}")

sv_data_df
__main__ - INFO - Total SV calls processed: 508922
Out[40]:
shape: (508_922, 10)
typelengthchromstartendallele_idxONTIlluminaMergedsample_id
stri64stri64i64i64boolboolboolstr
"INS"80"chr7"50802725508027250truefalsefalse"A046_12"
"DEL"61"chr7"48848787488487870falsetruetrue"A046_12"
"DEL"54"chr1"1016578781016578780truefalsefalse"A046_12"
"INS"31"chr11"55765038557650380truefalsefalse"A046_12"
"INS"132"chr8"5357315357310truefalsefalse"A046_12"
…………………………
"DEL"673"chr16"32507512325075120falsetruefalse"A162_09"
"DEL"205"chr16"88122050881220500falsetruetrue"A162_09"
"INS"512"chr17"26701545267015450truefalsefalse"A162_09"
"INS"54"chr3"76698102766981020truefalsefalse"A162_09"
"DEL"165"chrX"3645763645760truefalsefalse"A162_09"
In [41]:
def compare_sv_counts(sv_data_df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate SV counts across different technologies for each sample.

    Args:
        sv_data_df: Input DataFrame containing SV data with columns:
                   sample_id, ONT, Illumina, and Merged

    Returns:
        DataFrame with anonymized sample IDs and counts for each technology

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {"sample_id", "ONT", "Illumina", "Merged"}
        if not all(col in sv_data_df.columns for col in required_cols):
            missing = required_cols - set(sv_data_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        counts = (
            sv_data_df.group_by("sample_id")
            .agg(
                [
                    pl.col("ONT").sum().alias("long-read"),
                    pl.col("Illumina").sum().alias("short-read"),
                    pl.col("Merged").sum().alias("consensus"),
                ]
            )
            .sort("sample_id")
        )

        # Add anonymized sample IDs
        counts = counts.with_columns(
            [
                pl.Series(
                    name="anonymised_sample",
                    values=[f"Sample {i+1}" for i in range(counts.height)],
                )
            ]
        )

        return counts.select(
            ["anonymised_sample", "long-read", "short-read", "consensus"]
        )

    except Exception as e:
        logger.error(f"Error comparing SV counts: {str(e)}")
        raise


def plot_sv_counts(
    sv_counts_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 6),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot structural variant counts across different technologies for each sample.

    Args:
        sv_counts_df: DataFrame containing SV count data
        figsize: Figure size as (width, height)
        dpi: Figure resolution
        gs: Optional GridSpec for plotting within a larger figure

    Returns:
        Figure object if created independently (gs=None)

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {"anonymised_sample", "long-read", "short-read", "consensus"}
        if not all(col in sv_counts_df.columns for col in required_cols):
            missing = required_cols - set(sv_counts_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        plot_data = sv_counts_df.unpivot(
            index=["anonymised_sample"],
            on=["long-read", "short-read", "consensus"],
            variable_name="Technology",
            value_name="Count",
        )

        sns.barplot(
            data=plot_data, x="anonymised_sample", y="Count", hue="Technology", ax=ax
        )

        ax.set_title("SV Call Counts by Sample and Platform")
        ax.set_xlabel("Sample")
        ax.set_ylabel("Number of SV Calls")
        legend = ax.legend(title="Technology")
        legend.get_title().set_weight("bold")

        locs, labels = ax.get_xticks(), ax.get_xticklabels()
        ax.set_xticks([loc + 0.1 for loc in locs])

        for tick in ax.get_xticklabels():
            tick.set_rotation(45)
            tick.set_ha("right")

        if gs is None:
            plt.tight_layout()
            return fig
        else:
            return None

    except Exception as e:
        logger.error(f"Error plotting SV counts: {str(e)}")
        raise


sv_counts_df = compare_sv_counts(sv_data_df)

sv_counts_plot = plot_sv_counts(sv_counts_df)
No description has been provided for this image
In [42]:
def calculate_consensus_percentages(sv_counts_df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate consensus percentages between ONT/Illumina and consensus calls.

    Args:
        sv_counts_df: DataFrame containing SV count data with columns:
                     anonymised_sample, long-read, short-read, and consensus

    Returns:
        DataFrame with added consensus percentage columns and printed statistics

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {"long-read", "short-read", "consensus"}
        if not all(col in sv_counts_df.columns for col in required_cols):
            missing = required_cols - set(sv_counts_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        result_df = sv_counts_df.with_columns(
            [
                (pl.col("consensus") / pl.col("long-read") * 100)
                .fill_null(0)
                .alias("ONT_Consensus_Percent"),
                (pl.col("consensus") / pl.col("short-read") * 100)
                .fill_null(0)
                .alias("Illumina_Consensus_Percent"),
            ]
        )

        # Calculate statistics
        ont_stats = result_df.select(
            [
                pl.col("ONT_Consensus_Percent").mean().alias("mean"),
                pl.col("ONT_Consensus_Percent").std().alias("std"),
            ]
        ).row(0)

        illumina_stats = result_df.select(
            [
                pl.col("Illumina_Consensus_Percent").mean().alias("mean"),
                pl.col("Illumina_Consensus_Percent").std().alias("std"),
            ]
        ).row(0)

        logger.info(
            f"Average consensus percentage for ONT: "
            f"{ont_stats[0]:.2f}% ± {ont_stats[1]:.2f}% (mean ± SD)"
        )
        logger.info(
            f"Average consensus percentage for Illumina: "
            f"{illumina_stats[0]:.2f}% ± {illumina_stats[1]:.2f}% (mean ± SD)"
        )

        return result_df

    except Exception as e:
        logger.error(f"Error calculating consensus percentages: {str(e)}")
        raise


sv_consensus_df = calculate_consensus_percentages(sv_counts_df)
__main__ - INFO - Average consensus percentage for ONT: 20.61% ± 1.60% (mean ± SD)
__main__ - INFO - Average consensus percentage for Illumina: 57.87% ± 10.11% (mean ± SD)
In [43]:
def calculate_average_difference(sv_counts_df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate average ratio difference between ONT and Illumina SV counts.

    Args:
        sv_counts_df: DataFrame containing SV count data with long-read and short-read columns

    Returns:
        DataFrame with added ONT/Illumina ratio column

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {"long-read", "short-read"}
        if not all(col in sv_counts_df.columns for col in required_cols):
            missing = required_cols - set(sv_counts_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        result_df = sv_counts_df.with_columns(
            (pl.col("long-read") / pl.col("short-read"))
            .fill_null(0)
            .alias("ONT_Illumina_Ratio")
        )

        # Calculate statistics
        stats = result_df.select(
            [
                pl.col("ONT_Illumina_Ratio").mean().alias("mean"),
                pl.col("ONT_Illumina_Ratio").std().alias("std"),
            ]
        ).row(0)

        logger.info(
            f"Average ratio of SV counts between ONT and Illumina: "
            f"{stats[0]:.2f} ± {stats[1]:.2f} (mean ± SD)"
        )

        return result_df

    except Exception as e:
        logger.error(f"Error calculating average difference: {str(e)}")
        raise


average_diff_df = calculate_average_difference(sv_counts_df)
__main__ - INFO - Average ratio of SV counts between ONT and Illumina: 2.86 ± 0.70 (mean ± SD)

2. SV Size Distribution¶

In [44]:
def plot_sv_size_distributions(
    sv_data_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 6),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot length distributions of structural variants for ONT and Illumina data.

    Args:
        sv_data_df: DataFrame containing SV data with columns 'length', 'ONT', and 'Illumina'
        figsize: Figure size as (width, height)
        dpi: Figure resolution
        gs: Optional GridSpec for plotting within a larger figure

    Returns:
        Figure object if created independently (gs=None)

    Raises:
        ValueError: If required columns are missing or data is invalid
    """
    try:
        required_cols = {"length", "ONT", "Illumina"}
        if not all(col in sv_data_df.columns for col in required_cols):
            missing = required_cols - set(sv_data_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs[0, 0])

        # Extract and clean length data for each technology
        ont_lengths = (
            sv_data_df.filter(pl.col("ONT"))
            .select("length")
            .drop_nulls()
            .filter(pl.col("length").is_finite())
            .get_column("length")
            .to_list()
        )

        illumina_lengths = (
            sv_data_df.filter(pl.col("Illumina"))
            .select("length")
            .drop_nulls()
            .filter(pl.col("length").is_finite())
            .get_column("length")
            .to_list()
        )

        # Plot distributions
        sns.histplot(
            ont_lengths,
            log_scale=True,
            bins=50,
            stat="density",
            kde=True,
            alpha=0.7,
            label="long-read",
            ax=ax,
        )
        sns.histplot(
            illumina_lengths,
            log_scale=True,
            bins=50,
            stat="density",
            kde=True,
            alpha=0.7,
            label="short-read",
            ax=ax,
        )

        ax.set_title("SV Size Distribution")
        ax.set_xlabel("SV Size (bp)")
        ax.set_ylabel("Density")
        legend = ax.legend(title="Technology")
        legend.get_title().set_weight("bold")

        if gs is None:
            plt.tight_layout()
            return fig
        else:
            return None

    except Exception as e:
        logger.error(f"Error plotting SV length distributions: {str(e)}")
        raise


sv_size_plot = plot_sv_size_distributions(sv_data_df)
No description has been provided for this image
In [45]:
@dataclass
class SVSizeStats:
    """Data class for structural variant size statistics."""

    maximum: float
    minimum: float
    mean: float
    std: float
    median: float


def calculate_sv_size_stats(lengths: List[float]) -> SVSizeStats:
    """
    Calculate statistical measures for structural variant sizes.

    Args:
        lengths: List of SV lengths to analyze

    Returns:
        SVSizeStats object containing calculated statistics

    Raises:
        ValueError: If input list is empty
    """
    try:
        if not lengths:
            raise ValueError("Empty length list provided")

        return SVSizeStats(
            maximum=float(np.max(lengths)),
            minimum=float(np.min(lengths)),
            mean=float(np.mean(lengths)),
            std=float(np.std(lengths)),
            median=float(np.median(lengths)),
        )
    except Exception as e:
        logger.error(f"Error calculating SV size statistics: {str(e)}")
        raise


def format_number(num: Union[int, float]) -> str:
    """
    Format numbers with appropriate separators and decimal places.

    Args:
        num: Number to format

    Returns:
        Formatted string representation of the number
    """
    try:
        if isinstance(num, (int, np.integer)):
            return f"{num:,d}"
        elif isinstance(num, float):
            return f"{num:,.2f}"
        return str(num)
    except Exception as e:
        logger.error(f"Error formatting number {num}: {str(e)}")
        return str(num)


def analyze_sv_size_distributions(sv_data_df: pl.DataFrame) -> Dict[str, SVSizeStats]:
    """
    Analyze size distributions of structural variants across technologies.

    Args:
        sv_data_df: DataFrame containing SV data with length and technology columns

    Returns:
        Dictionary mapping technology names to their SVSizeStats

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {"length", "ONT", "Illumina"}
        if not all(col in sv_data_df.columns for col in required_cols):
            missing = required_cols - set(sv_data_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        # Extract lengths for each technology
        ont_lengths = (
            sv_data_df.filter(pl.col("ONT"))
            .select("length")
            .filter(pl.col("length").is_finite())
            .to_series()
            .to_list()
        )

        illumina_lengths = (
            sv_data_df.filter(pl.col("Illumina"))
            .select("length")
            .filter(pl.col("length").is_finite())
            .to_series()
            .to_list()
        )

        # Calculate statistics
        stats = {
            "ONT": calculate_sv_size_stats(ont_lengths),
            "Illumina": calculate_sv_size_stats(illumina_lengths),
        }

        # Print results
        for tech, tech_stats in stats.items():
            print(f"\n{tech} SV Length Statistics:")
            print("=" * 40)
            for stat_name, value in tech_stats.__dict__.items():
                print(f"  {stat_name.capitalize():6s}: {format_number(value)}")

        return stats
    except Exception as e:
        logger.error(f"Error analyzing SV size distributions: {str(e)}")
        raise


sv_size_dist = analyze_sv_size_distributions(sv_data_df)
ONT SV Length Statistics:
========================================
  Maximum: 129,371,498.00
  Minimum: 12.00
  Mean  : 3,527.59
  Std   : 482,253.95
  Median: 80.00

Illumina SV Length Statistics:
========================================
  Maximum: 6,064.00
  Minimum: 2.00
  Mean  : 165.62
  Std   : 164.47
  Median: 96.00

3. SV Types¶

In [46]:
@dataclass
class SVTypeStats:
    """Data class for structural variant type statistics."""

    mean: float
    median: float
    std_dev: float


def calculate_sv_type_stats(counts: List[int]) -> SVTypeStats:
    """
    Calculate statistics for SV type counts.

    Args:
        counts: List of counts for a specific SV type

    Returns:
        SVTypeStats object containing calculated statistics

    Raises:
        ValueError: If input list is empty
    """
    try:
        if not counts:
            raise ValueError("Empty counts list provided")

        return SVTypeStats(
            mean=float(np.mean(counts)),
            median=float(np.median(counts)),
            std_dev=float(np.std(counts)),
        )
    except Exception as e:
        print(f"Error calculating SV type statistics: {str(e)}")
        raise


def analyze_sv_types(sv_data_df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """
    Analyze distribution of structural variant types across samples and technologies.

    Args:
        sv_data_df: DataFrame containing SV data with type, sample_id, and technology columns

    Returns:
        Tuple containing:
            - DataFrame with SV type counts per sample and platform
            - DataFrame with statistical summary of SV types across platforms

    Raises:
        ValueError: If required columns are missing
    """
    try:
        required_cols = {"type", "sample_id", "ONT", "Illumina"}
        if not all(col in sv_data_df.columns for col in required_cols):
            missing = required_cols - set(sv_data_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        # Initialize data structures
        type_counts: Dict[Tuple[str, str], Dict[str, int]] = defaultdict(dict)
        type_stats: Dict[str, Dict[str, List[int]]] = defaultdict(
            lambda: defaultdict(list)
        )

        # Process each sample and platform
        for sample_id in sv_data_df.get_column("sample_id").unique():
            sample_data = sv_data_df.filter(pl.col("sample_id") == sample_id)

            for platform in ["ONT", "Illumina"]:
                platform_data = sample_data.filter(pl.col(platform))

                # Calculate type counts
                type_counts_dict = (
                    platform_data.get_column("type")
                    .value_counts()
                    .sort("count", descending=True)
                    .to_dict(as_series=False)
                )

                # Store counts
                for sv_type, count in zip(
                    type_counts_dict["type"], type_counts_dict["count"]
                ):
                    type_counts[(sample_id, platform.lower())][sv_type] = count
                    type_stats[platform.lower()][sv_type].append(count)

        # Create counts DataFrame
        counts_data = [
            {"sample_id": sample_id, "platform": platform, **counts}
            for (sample_id, platform), counts in type_counts.items()
        ]
        df_sv_type_counts = pl.DataFrame(counts_data)

        # Calculate statistics
        stats_data = [
            {
                "platform": platform,
                "sv_type": sv_type,
                **calculate_sv_type_stats(counts).__dict__,
            }
            for platform, sv_types in type_stats.items()
            for sv_type, counts in sv_types.items()
        ]
        df_sv_type_stats = pl.DataFrame(stats_data)

        # Print summary
        print("\nSV Type Counts Summary:")
        print("=" * 40)
        with pl.Config(tbl_rows=len(df_sv_type_counts)):
            display(df_sv_type_counts)

        print("\nSV Type Statistics Summary:")
        print("=" * 40)
        display(df_sv_type_stats)

        return df_sv_type_counts, df_sv_type_stats

    except Exception as e:
        print(f"Error analyzing SV types: {str(e)}")
        raise


sv_type_counts_df, sv_type_stats_df = analyze_sv_types(sv_data_df)
SV Type Counts Summary:
========================================
shape: (28, 8)
sample_idplatformINSDELSTRBNDINVDUP
strstri64i64i64i64i64i64
"A097_92""ont"12668105242320138
"A097_92""illumina"32605578null7651null
"A157_02""ont"166821359733752117
"A157_02""illumina"32585560null7881null
"A079_07""ont"1158793651815135
"A079_07""illumina"33535539null7251null
"A154_06""ont"17361140813466269
"A154_06""illumina"30595170null666nullnull
"A153_01""ont"163301346526632110
"A153_01""illumina"32225463null8011null
"A081_91""ont"89787301111085
"A081_91""illumina"34725644null8111null
"A153_06""ont"187051516633983218
"A153_06""illumina"30885429null7021null
"A154_04""ont"16385133483358278
"A154_04""illumina"32335670null7341null
"A085_00""ont"96627898121785
"A085_00""illumina"30505313null7321null
"A162_09""ont"1933215542361044026
"A162_09""illumina"32745493null7421null
"A048_09""ont"13394111602128216
"A048_09""illumina"34355715null769nullnull
"A046_12""ont"1152996251816198
"A046_12""illumina"32625392null675nullnull
"A149_01""ont"152591246828502212
"A149_01""illumina"30835425null6941null
"A160_96""ont"1911315504331243617
"A160_96""illumina"31725491null7731null
SV Type Statistics Summary:
========================================
shape: (10, 5)
platformsv_typemeanmedianstd_dev
strstrf64f64f64
"ont""INS"14784.64285715794.53352.940769
"ont""DEL"12074.57142912908.02675.388322
"ont""STR"25.64285727.08.21677
"ont""BND"53.14285754.035.981855
"ont""INV"21.92857121.09.361504
"ont""DUP"11.08.56.047432
"illumina""DEL"5491.5714295492.0139.839047
"illumina""INS"3230.0714293245.5127.395035
"illumina""BND"741.214286738.044.070039
"illumina""INV"1.01.00.0
In [47]:
def plot_sv_types(
    sv_data_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 8),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot structural variant types by sample for both long-read and short-read data.

    Args:
        sv_data_df: DataFrame containing SV data with columns 'sample_id', 'type', 'ONT', and 'Illumina'
        figsize: Figure size as (width, height)
        dpi: Figure resolution
        gs: Optional GridSpec for plotting within a larger figure

    Returns:
        Figure object if created independently (gs=None)

    Raises:
        ValueError: If required columns are missing or data is invalid
    """
    try:
        required_cols = {"sample_id", "type", "ONT", "Illumina"}
        if not all(col in sv_data_df.columns for col in required_cols):
            missing = required_cols - set(sv_data_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        # Create figure and subplots
        if gs is None:
            fig = plt.figure(figsize=figsize, dpi=dpi)
            gs_local = gridspec.GridSpec(2, 1, figure=fig)
        else:
            fig = plt.gcf()
            gs_local = gs

        # Platform configuration
        platform_config = {
            "ONT": {"title": "Long-read", "position": 0},
            "Illumina": {"title": "Short-read", "position": 1},
        }

        # Map sample IDs to anonymised sample IDs using nanoplot_qc_metrics_df
        sample_map = dict(
            zip(
                nanoplot_qc_metrics_df.get_column("sample").to_list(),
                nanoplot_qc_metrics_df.get_column("anonymised_sample").to_list(),
            )
        )

        # Compute globally sorted variant types (alphabetically)
        all_variant_types = sorted(sv_data_df.get_column("type").unique().to_list())

        # Loop over platforms
        for platform, config in platform_config.items():
            rows, cols = gs_local.get_geometry()
            if rows == 1:
                ax = fig.add_subplot(gs_local[0, config["position"]])
            else:
                ax = fig.add_subplot(gs_local[config["position"], 0])

            # Filter and prepare data for platform using Polars
            platform_data = (
                sv_data_df.filter(pl.col(platform))
                .with_columns(
                    pl.col("sample_id").replace(sample_map).alias("anonymised_sample")
                )
                .group_by(["anonymised_sample", "type"])
                .agg(pl.len().alias("count"))
                .pivot(values="count", index="anonymised_sample", on="type")
                .fill_null(0)
                .with_columns(
                    pl.col("anonymised_sample")
                    .str.extract(r"(\d+)$")
                    .cast(pl.Int64)
                    .alias("sample_number")
                )
                .sort("sample_number")
                .drop("sample_number")
            )

            # Ensure all variant type columns exist and are ordered alphabetically
            for vt in all_variant_types:
                if vt not in platform_data.columns:
                    platform_data = platform_data.with_columns(pl.lit(0).alias(vt))
            platform_data = platform_data.select(
                ["anonymised_sample"] + all_variant_types
            )

            # Plot stacked bar chart directly using Matplotlib
            x = platform_data.get_column("anonymised_sample").to_list()
            bottom = np.zeros(len(x))
            for vt in all_variant_types:
                counts = platform_data.get_column(vt).to_numpy()
                ax.bar(x, counts, bottom=bottom, label=vt)
                bottom += counts

            ax.set_title(f"SV Types by Sample - {config['title']}")
            ax.set_xlabel("Sample")
            ax.set_ylabel("Number of SVs")
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
            locs, labels = ax.get_xticks(), ax.get_xticklabels()
            ax.set_xticks([loc + 0.2 for loc in locs])
            if gs is None and config["position"] == 0:
                legend = ax.legend(title="SV Type", bbox_to_anchor=(1, 1))
                legend.get_title().set_weight("bold")
            elif gs is not None and config["position"] == 1:
                legend = ax.legend(title="SV Type", bbox_to_anchor=(1.05, 1.05))
                legend.get_title().set_weight("bold")

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error plotting SV types: {str(e)}")
        raise


sv_types_plot = plot_sv_types(sv_data_df)
No description has been provided for this image

Combined Plots¶

In [48]:
def create_combined_sv_analysis_plot(
    sv_data_df: pl.DataFrame,
    sv_counts_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 8),
    dpi: int = 300,
) -> plt.Figure:
    """
    Create a combined figure showing SV analysis plots in 2x2 grid.

    Args:
        sv_data_df: DataFrame containing structural variant data
        sv_counts_df: DataFrame containing SV count data
        figsize: Figure size as (width, height)
        dpi: Figure resolution

    Returns:
        Combined figure object
    """
    try:
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(2, 2)

        # Plot SV Calls per Platform (A)
        plot_sv_counts(
            sv_counts_df,
            gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, 0]),
        )

        # Plot SV Size Distribution (B)
        plot_sv_size_distributions(
            sv_data_df, gs=gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0, 1])
        )

        # Plot Long Read and Short Read SV Types (C & D)
        plot_sv_types(
            sv_data_df, gs=gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[1, :])
        )

        # Add panel labels
        for i, label in enumerate(["A", "B", "C", "D"]):
            ax = fig.axes[i]
            ax.text(
                -0.1,
                1.05,
                label,
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
                va="top",
            )

        fig.set_constrained_layout(True)
        return fig

    except Exception as e:
        logger.error(f"Error creating combined SV plot: {str(e)}")
        raise


combined_sv_analysis_plot = create_combined_sv_analysis_plot(sv_data_df, sv_counts_df)
No description has been provided for this image
In [49]:
@dataclass
class SVTypeConfig:
    """Configuration for SV type display names and plotting settings."""

    name: str
    use_log_scale: bool = True


def plot_sv_size_distribution_by_type(
    sv_data_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 8),
    dpi: int = 300,
) -> plt.Figure:
    """
    Create a combined figure showing size distribution of structural variants by type
    for both long-read and short-read data in a 3x2 grid.

    Args:
        sv_data_df: DataFrame containing SV data with columns 'type', 'length', 'ONT', and 'Illumina'
        figsize: Figure size as (width, height)
        dpi: Figure resolution

    Returns:
        Combined figure object

    Raises:
        ValueError: If required columns are missing or data is invalid
    """
    try:
        required_cols = {"type", "length", "ONT", "Illumina"}
        if not all(col in sv_data_df.columns for col in required_cols):
            missing = required_cols - set(sv_data_df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        # Configuration
        sv_types_config = {
            "INS": SVTypeConfig("Insertion"),
            "DEL": SVTypeConfig("Deletion"),
            "DUP": SVTypeConfig("Duplication"),
            "INV": SVTypeConfig("Inversion"),
            "BND": SVTypeConfig("Breakend", use_log_scale=False),
            "STR": SVTypeConfig("Short Tandem Repeat", use_log_scale=False),
        }

        platform_config = {
            "ONT": ("long-read", 0.7),
            "Illumina": ("short-read", 0.7),
        }

        # Create main figure and GridSpec
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(3, 2)

        # Create subplots
        for idx, (sv_type, config) in enumerate(sv_types_config.items()):
            ax = fig.add_subplot(gs[idx // 2, idx % 2])

            for platform, (label, alpha) in platform_config.items():
                sv_data = sv_data_df.filter(
                    (pl.col("type") == sv_type) & pl.col(platform)
                )

                if not sv_data.is_empty():
                    sns.histplot(
                        data=sv_data,
                        x="length",
                        kde=True,
                        log_scale=config.use_log_scale,
                        stat="density",
                        ax=ax,
                        alpha=alpha,
                        label=label,
                    )

            ax.set_title(f"{config.name} Size Distribution")
            ax.set_xlabel(f"SV Size (bp)")
            ax.set_ylabel("Density")

            # Add panel labels
            ax.text(
                -0.1,
                1.05,
                chr(65 + idx),  # A, B, C, D, E, F
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
                va="top",
            )

            if idx == 0:
                legend = ax.legend(title="Technology")
                legend.get_title().set_weight("bold")

        fig.set_constrained_layout(True)
        return fig

    except Exception as e:
        logger.error(f"Error plotting SV size distribution: {str(e)}")
        raise


sv_size_dist_plot = plot_sv_size_distribution_by_type(sv_data_df)
No description has been provided for this image

4. SV Chromosomal Distribution¶

In [50]:
@dataclass
class ChromosomeData:
    """Data class for chromosome lengths."""

    lengths: Dict[str, int] = field(
        default_factory=lambda: {
            "chr1": 248956422,
            "chr2": 242193529,
            "chr3": 198295559,
            "chr4": 190214555,
            "chr5": 181538259,
            "chr6": 170805979,
            "chr7": 159345973,
            "chr8": 145138636,
            "chr9": 138394717,
            "chr10": 133797422,
            "chr11": 135086622,
            "chr12": 133275309,
            "chr13": 114364328,
            "chr14": 107043718,
            "chr15": 101991189,
            "chr16": 90338345,
            "chr17": 83257441,
            "chr18": 80373285,
            "chr19": 58617616,
            "chr20": 64444167,
            "chr21": 46709983,
            "chr22": 50818468,
            "chrX": 156040895,
            "chrY": 57227415,
        }
    )

    @property
    def ordered_chroms(self) -> list[str]:
        """Return chromosomes in proper order."""
        return [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]


def normalize_by_chrom_length(
    chrom_distribution_df: pl.DataFrame, chrom_data: ChromosomeData
) -> pl.DataFrame:
    """
    Normalize SV counts by chromosome length.

    Args:
        chrom_distribution_df: DataFrame with chromosome distribution
        chrom_data: ChromosomeData instance with chromosome lengths

    Returns:
        Normalized DataFrame

    Raises:
        ValueError: If chromosome data is missing
    """
    try:
        # Create a dictionary for chromosome lengths in millions of base pairs
        chrom_lengths_mb = {k: v / 1e6 for k, v in chrom_data.lengths.items()}

        # Convert to LazyFrame for more efficient operations
        normalized = chrom_distribution_df.lazy()

        # Create normalized columns using replace_strict instead of map_elements
        normalized = normalized.with_columns(
            [
                (
                    pl.col("ont") / pl.col("chrom").replace_strict(chrom_lengths_mb)
                ).alias("ont"),
                (
                    pl.col("illumina")
                    / pl.col("chrom").replace_strict(chrom_lengths_mb)
                ).alias("illumina"),
            ]
        )

        return normalized.collect()

    except Exception as e:
        logger.error(f"Error normalizing chromosome distribution: {str(e)}")
        raise


def analyze_chrom_distribution(sv_data_df: pl.DataFrame) -> pl.DataFrame:
    """
    Analyze the distribution of structural variants across chromosomes for ONT and Illumina platforms.

    Args:
        sv_data_df: DataFrame containing SV data with ONT and Illumina boolean columns
                   and a 'chrom' column

    Returns:
        DataFrame with chromosome distribution counts for both platforms

    Raises:
        ValueError: If required columns are missing
    """
    try:
        # Define valid chromosomes
        valid_chroms = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]

        # Create counts for ONT data
        ont_counts = (
            sv_data_df.filter(pl.col("ONT") == True)
            .group_by("chrom")
            .len()
            .with_columns(pl.col("len").alias("ont"))
            .drop("len")
        )

        # Create counts for Illumina data
        illumina_counts = (
            sv_data_df.filter(pl.col("Illumina") == True)
            .group_by("chrom")
            .len()
            .with_columns(pl.col("len").alias("illumina"))
            .drop("len")
        )

        # Join the counts and fill missing values with 0
        result = (
            pl.DataFrame({"chrom": valid_chroms})
            .join(ont_counts, on="chrom", how="left")
            .join(illumina_counts, on="chrom", how="left")
            .with_columns([pl.col("ont").fill_null(0), pl.col("illumina").fill_null(0)])
        )

        return result

    except Exception as e:
        logger.error(f"Error analyzing chromosome distribution: {str(e)}")
        raise


def plot_chrom_distribution(
    chrom_distribution_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 6),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot chromosome distribution of structural variants.

    Args:
        chrom_distribution_df: DataFrame with chromosome distribution
        figsize: Figure size as (width, height)
        dpi: Figure resolution
        gs: Optional GridSpec for plotting within a larger figure

    Returns:
        Figure object if created independently (gs=None)

    Raises:
        ValueError: If required data is missing
    """
    try:
        chrom_data = ChromosomeData()
        normalized_df = normalize_by_chrom_length(chrom_distribution_df, chrom_data)

        total_svs = normalized_df.select(
            [pl.col("ont").sum(), pl.col("illumina").sum()]
        )

        normalized_df = normalized_df.with_columns(
            [
                (pl.col("ont") / total_svs.item(0, 0) * 100).alias("long-read"),
                (pl.col("illumina") / total_svs.item(0, 1) * 100).alias("short-read"),
            ]
        ).drop(["ont", "illumina"])

        plot_data = normalized_df.unpivot(
            index=["chrom"],
            on=["long-read", "short-read"],
            variable_name="Platform",
            value_name="value",
        )

        plot_data = (
            plot_data.with_columns(pl.col("chrom").cast(pl.Categorical))
            .with_columns(
                pl.col("chrom")
                .cast(pl.Categorical)
                .map_elements(
                    lambda x: (
                        chrom_data.ordered_chroms.index(x)
                        if x in chrom_data.ordered_chroms
                        else -1
                    ),
                    return_dtype=pl.Int64,
                )
                .alias("chrom_order")
            )
            .sort("chrom_order")
            .drop("chrom_order")
        )

        if gs is None:
            fig = plt.figure(figsize=figsize, dpi=dpi)
            ax = fig.add_subplot(111)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs)

        sns.barplot(
            data=plot_data,
            x="chrom",
            y="value",
            hue="Platform",
            order=chrom_data.ordered_chroms,
            ax=ax,
        )

        ax.set_title("Normalised Chromosomal Distribution of SVs")
        ax.set_xlabel("Chromosome")
        ax.set_ylabel("Proportion of SVs per Mb (%)")
        ax.legend(title="Technology").get_title().set_weight("bold")

        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        locs, labels = ax.get_xticks(), ax.get_xticklabels()
        ax.set_xticks([loc + 0.18 for loc in locs])

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error plotting chromosome distribution: {str(e)}")
        raise


chrom_distribution_df = analyze_chrom_distribution(sv_data_df)
sv_chrom_dist_plot = plot_chrom_distribution(chrom_distribution_df)
No description has been provided for this image
In [51]:
def calculate_sv_correlations(
    chrom_distribution_df: pl.DataFrame, chromosome_data: ChromosomeData
) -> Tuple[pl.DataFrame, float, float]:
    """
    Calculate Pearson correlations between chromosome length and mean SV counts.

    Args:
        chrom_distribution_df: Polars DataFrame with SV counts per chromosome.
        chromosome_data: ChromosomeData instance containing chromosome lengths.

    Returns:
        Tuple containing:
        - Polars DataFrame with correlation data.
        - Pearson correlation coefficient for ONT.
        - Pearson correlation coefficient for Illumina.
    """
    # Use ordered chromosomes from ChromosomeData
    ordered_chroms = chromosome_data.ordered_chroms

    # Calculate mean counts for each chromosome with consistent ordering
    ont_mean_count = (
        chrom_distribution_df.group_by("chrom")
        .agg(pl.col("ont").mean().alias("ont_mean"))
        .filter(pl.col("chrom").is_in(ordered_chroms))
        .sort("chrom", descending=False)
    )

    illumina_mean_count = (
        chrom_distribution_df.group_by("chrom")
        .agg(pl.col("illumina").mean().alias("illumina_mean"))
        .filter(pl.col("chrom").is_in(ordered_chroms))
        .sort("chrom", descending=False)
    )

    # Create ordered lists of chromosome lengths
    ordered_lengths = [chromosome_data.lengths[chrom] for chrom in ordered_chroms]

    # Create correlation DataFrame with consistent ordering
    corr_data = pl.DataFrame(
        {
            "chrom": ordered_chroms,
            "length": ordered_lengths,
            "ont_count": ont_mean_count["ont_mean"].to_list(),
            "illumina_count": illumina_mean_count["illumina_mean"].to_list(),
        }
    )

    # Calculate correlations using ordered data
    ont_corr, ont_p = stats.pearsonr(corr_data["length"], corr_data["ont_count"])
    illumina_corr, illumina_p = stats.pearsonr(
        corr_data["length"], corr_data["illumina_count"]
    )

    print(f"ONT correlation: {ont_corr:.2f} (p-value: {ont_p:.2e})")
    print(f"Illumina correlation: {illumina_corr:.2f} (p-value: {illumina_p:.2e})")

    return corr_data, ont_corr, illumina_corr


def plot_sv_correlations(
    corr_data: pl.DataFrame,
    ont_corr: float,
    illumina_corr: float,
    figsize: Tuple[int, int] = (12, 5),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot correlations between chromosome length and SV counts for both technologies.

    Args:
        corr_data: Polars DataFrame containing correlation data.
        ont_corr: Pearson correlation coefficient for ONT data.
        illumina_corr: Pearson correlation coefficient for Illumina data.
        figsize: Figure size as (width, height).
        dpi: Figure resolution.
        gs: Optional GridSpec for plotting within a larger figure.

    Returns:
        Figure object if created independently (gs=None).

    Raises:
        ValueError: If required correlation data is missing.
    """
    try:
        if gs is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        else:
            fig = plt.gcf()
            ax = fig.add_subplot(gs)

        _, ont_p = stats.pearsonr(corr_data["length"], corr_data["ont_count"])
        _, illumina_p = stats.pearsonr(corr_data["length"], corr_data["illumina_count"])

        # Plot long-read data
        sns.regplot(
            x="length",
            y="ont_count",
            data=corr_data,
            ax=ax,
            label=f"long-read (r={ont_corr:.2f}, p={ont_p:.2e})",
            scatter_kws={"alpha": 0.8, "label": "long-read data points"},
            line_kws={
                "color": sns.color_palette()[0],
                "label": "long-read regression line",
            },
        )

        # Plot short-read data
        sns.regplot(
            x="length",
            y="illumina_count",
            data=corr_data,
            ax=ax,
            label=f"short-read (r={illumina_corr:.2f}, p={illumina_p:.2e})",
            scatter_kws={"alpha": 0.8, "label": "short-read data points"},
            line_kws={
                "color": sns.color_palette()[1],
                "label": "short-read regression line",
            },
        )

        ax.set_title("Structural Variant Counts vs Chromosome Length")
        ax.set_xlabel("Chromosome Length (bp)")
        ax.set_ylabel("Number of Structural Variants")

        # Create custom legend
        handles, labels = ax.get_legend_handles_labels()

        # Add line and CI band descriptions to legend
        legend_elements = [
            Line2D(
                [0],
                [0],
                color=sns.color_palette()[0],
                label=f"long-read (r={ont_corr:.2f}, p={ont_p:.2e})",
            ),
            Line2D(
                [0],
                [0],
                color=sns.color_palette()[1],
                label=f"short-read (r={illumina_corr:.2f}, p={illumina_p:.2e})",
            ),
            Patch(
                facecolor=sns.color_palette()[0], alpha=0.2, label="long-read 95% CI"
            ),
            Patch(
                facecolor=sns.color_palette()[1], alpha=0.2, label="short-read 95% CI"
            ),
        ]

        ax.legend(handles=legend_elements, title="Technology")
        ax.get_legend().get_title().set_weight("bold")

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error plotting SV correlations: {str(e)}")
        raise


chromosome_data = ChromosomeData()

corr_data, ont_corr, illumina_corr = calculate_sv_correlations(
    chrom_distribution_df, chromosome_data
)

sv_length_corr_plot = plot_sv_correlations(corr_data, ont_corr, illumina_corr)
ONT correlation: 0.11 (p-value: 6.16e-01)
Illumina correlation: 0.05 (p-value: 8.24e-01)
No description has been provided for this image

Combined Plots¶

In [52]:
def create_combined_sv_chr_plot(
    sv_data_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 4),  # Modified for side-by-side layout
    dpi: int = 300,
) -> plt.Figure:
    """
    Create a combined figure showing chromosome distribution and correlation plots side by side.

    Args:
        sv_data_df: DataFrame containing SV data
        figsize: Figure size as (width, height)
        dpi: Figure resolution

    Returns:
        Combined figure object
    """
    try:
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(1, 2)

        plot_chrom_distribution(chrom_distribution_df, gs=gs[0])

        ax_left = plt.gcf().get_axes()[0]
        ax_left.text(
            -0.1,
            1.05,
            "A",
            transform=ax_left.transAxes,
            fontsize=12,
            fontweight="bold",
            va="top",
        )

        plot_sv_correlations(corr_data, ont_corr, illumina_corr, gs=gs[1])

        ax_right = plt.gcf().get_axes()[1]
        ax_right.text(
            -0.1,
            1.05,
            "B",
            transform=ax_right.transAxes,
            fontsize=12,
            fontweight="bold",
            va="top",
        )

        fig.set_constrained_layout(True)
        return fig

    except Exception as e:
        logger.error(f"Error creating combined SV analysis plot: {str(e)}")
        raise


combined_sv_chr_plot = create_combined_sv_chr_plot(sv_data_df)
No description has been provided for this image

5. Impact of sequencing depth on structural variants¶

In [53]:
def _fit_and_plot_asymptotic_curve(
    x: np.ndarray,
    y: np.ndarray,
    ax: plt.Axes,
    color: str = "#1f77b4",  # Default matplotlib blue
    alpha: float = 0.2,
) -> None:
    """Fit and plot asymptotic curve with confidence intervals.

    Args:
        x: Input x values (depth)
        y: Input y values (SV counts)
        ax: Matplotlib axes object to plot on
        color: Color for the curve and confidence interval. Defaults to matplotlib blue
        alpha: Transparency for confidence interval. Defaults to 0.2

    Raises:
        RuntimeError: If curve fitting fails
    """

    def asymptotic_func(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray:
        """Asymptotic function for curve fitting."""
        return a - b * np.exp(-c * x)

    try:
        # Fit curve
        popt, pcov = curve_fit(
            asymptotic_func,
            x,
            y,
            p0=[np.max(y), np.max(y) - np.min(y), 0.1],
            bounds=([0, 0, 0], [np.inf, np.inf, 1]),
        )

        # Generate points for smooth curve
        x_range = np.linspace(x.min(), x.max(), 100)
        y_fit = asymptotic_func(x_range, *popt)

        ax.plot(
            x_range,
            y_fit,
            color=color,
            linestyle="-",
            linewidth=2,
            label="Line of best fit",
        )

        # Calculate and plot confidence intervals
        perr = np.sqrt(np.diag(pcov))
        n = len(x)
        dof = max(0, n - len(popt))
        t = stats.t.ppf(0.975, dof)
        y_err = np.sqrt(np.sum((y - asymptotic_func(x, *popt)) ** 2) / dof)

        ci = (
            t
            * y_err
            * np.sqrt(
                1 / n + (x_range - np.mean(x)) ** 2 / np.sum((x - np.mean(x)) ** 2)
            )
        )

        ax.fill_between(
            x_range,
            y_fit - ci,
            y_fit + ci,
            color=color,
            alpha=alpha,
            label="95% Confidence Interval",
        )

    except RuntimeError as e:
        logger.warning(f"Failed to fit asymptotic curve: {str(e)}")
        raise


def _plot_sv_depth_correlation(
    data: pl.DataFrame, metric: str, label: str, ax: plt.Axes, color: Any
) -> None:
    """Plot correlation between structural variant counts and sequencing depth.

    Args:
        data: Polars DataFrame containing depth and SV count data
        metric: Column name for SV counts (e.g., 'ONT', 'Illumina', 'Merged')
        label: Label for the plot title and legend
        ax: Matplotlib axes object to plot on
        color: Color for the scatter plot and fitted curve

    Returns:
        None

    Raises:
        RuntimeError: If asymptotic curve fitting fails
    """
    x = data.get_column("wg_mean_depth").to_numpy()
    y = data.get_column(metric).to_numpy()

    sns.scatterplot(x=x, y=y, ax=ax, color=color)

    try:
        _fit_and_plot_asymptotic_curve(x, y, ax, color=color)
    except RuntimeError:
        logger.warning(f"Could not fit asymptotic curve for {label}")

    r_value, p_value = stats.pearsonr(x, y)

    ax.set_title(f"{label}\nr = {r_value:.2f}, p = {p_value:.2e}")
    ax.set_xlabel("Whole Genome Mean Depth")
    ax.set_ylabel("Number of SV Calls")


def plot_depth_vs_sv_performance(
    wg_depth_df: pl.DataFrame,
    sv_data_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 4),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """Plot relationship between sequencing depth and SV detection performance.

    Args:
        wg_depth_df: DataFrame containing whole genome depth statistics
        sv_data_df: DataFrame containing structural variant data
        figsize: Figure dimensions. Defaults to (14, 4)
        dpi: Figure resolution. Defaults to 300
        gs: GridSpec for subplot placement. If None, creates standalone figure

    Returns:
        Optional[plt.Figure]: If gs is None, returns the figure.
        If gs is provided, returns None (plots are added as subfigures).

    Raises:
        ValueError: If required columns are missing from input DataFrames
    """
    try:
        depth_data = (
            wg_depth_df.filter(pl.col("chrom") == "chr1")
            .select(["sample", "mean"])
            .rename({"mean": "wg_mean_depth"})
        )

        sv_counts = sv_data_df.group_by("sample_id").agg(
            [pl.col("ONT").sum(), pl.col("Merged").sum()]
        )

        analysis_data = sv_counts.join(
            depth_data, left_on="sample_id", right_on="sample"
        )

        if gs is None:
            fig = plt.figure(figsize=figsize, dpi=dpi)
            gs = gridspec.GridSpec(1, 2, figure=fig)
        else:
            fig = plt.gcf()

        metrics = ["ONT", "Merged"]
        labels = ["Long-read", "Consensus"]

        palette = sns.color_palette()
        for i, (metric, label) in enumerate(zip(metrics, labels)):
            ax = plt.subplot(gs[0, i])
            _plot_sv_depth_correlation(
                analysis_data, metric, label, ax, color=palette[i]
            )
            if i == 0:  # Only add legend to the first plot
                ax.legend()

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error plotting depth vs SV performance: {str(e)}")
        raise


sv_performance_plot = plot_depth_vs_sv_performance(wg_depth_df, sv_data_df)
No description has been provided for this image
In [54]:
def prepare_sv_depth_data(
    total_depth_df: pl.DataFrame,
    sv_data_df: pl.DataFrame,
) -> Tuple[pl.DataFrame, List[str]]:
    """
    Prepare and join SV and depth data for analysis.

    Args:
        total_depth_df: DataFrame containing total depth statistics
        sv_data_df: DataFrame containing structural variant data

    Returns:
        Tuple containing:
            - Joined DataFrame with SV and depth information
            - List of unique SV types

    Raises:
        ValueError: If required columns are missing
    """
    try:
        analysis_data = sv_data_df.join(
            total_depth_df.select(["sample", "mean_depth"]),
            left_on="sample_id",
            right_on="sample",
        )

        sv_types = sorted(analysis_data.get_column("type").unique().to_list())
        return analysis_data, sv_types

    except Exception as e:
        logger.error(f"Error preparing SV depth data: {str(e)}")
        raise


def _plot_sv_size_distribution(
    platform_data: pl.DataFrame,
    sv_types: List[str],
    label: str,
    ax: plt.Axes,
    gs: gridspec.GridSpec,
) -> None:
    """
    Helper function to plot SV size distribution for a specific platform.

    Args:
        platform_data: DataFrame filtered for specific platform
        sv_types: List of unique SV types
        label: Platform label (Long-read/Short-read)
        ax: Matplotlib axes object
        gs: GridSpec for subplot placement
    """
    x = platform_data.get_column("mean_depth").to_numpy()
    y = np.log10(platform_data.get_column("length").to_numpy())
    r, p = stats.pearsonr(x, y)

    # Add correlation stats below title
    ax.set_title(f"{label} SV Size vs Depth\nr = {r:.3f}, p = {p:.2e}")

    palette = dict(zip(sv_types, sns.color_palette(n_colors=len(sv_types))))

    # Plot regression line with CI
    reg_plot = sns.regplot(
        data=platform_data,
        x="mean_depth",
        y="length",
        scatter=False,
        ax=ax,
        color="grey",
        line_kws={"linestyle": "-", "alpha": 0.8, "label": "Line of best fit"},
        ci=95,
    )

    # Get the CI lines for legend
    ci_lines = [line for line in ax.lines if line != reg_plot.lines[0]]
    if ci_lines:
        ci_lines[0].set_label("95% CI")

    # Then plot the scatter points on top
    sns.scatterplot(
        data=platform_data,
        x="mean_depth",
        y="length",
        hue="type",
        hue_order=sv_types,
        palette=palette,
        ax=ax,
        alpha=0.6,
    )

    ax.set_yscale("log")
    ax.set_xlabel("Whole Genome Mean Depth")
    ax.set_ylabel("SV Size (bp)")


def plot_sv_size_vs_depth(
    analysis_data: pl.DataFrame,
    sv_types: List[str],
    figsize: Tuple[int, int] = (12, 4),
    dpi: int = 300,
    gs: Optional[gridspec.GridSpec] = None,
) -> Optional[plt.Figure]:
    """
    Plot relationship between SV sizes and sequencing depth.

    Args:
        analysis_data: Prepared DataFrame containing joined SV and depth data
        sv_types: List of unique SV types
        figsize: Figure dimensions. Defaults to (8, 4)
        dpi: Figure resolution. Defaults to 300
        gs: GridSpec for subplot placement. If None, creates standalone figure

    Returns:
        Optional[plt.Figure]: If gs is None, returns the figure.
        If gs is provided, returns None (plots are added as subfigures).
    """
    try:
        if gs is None:
            fig = plt.figure(figsize=figsize, dpi=dpi)
            gs = gridspec.GridSpec(1, 1, figure=fig)
        else:
            fig = plt.gcf()

        # Removed Illumina from platforms list
        platforms = [("ONT", "Long-read")]

        for idx, (platform, label) in enumerate(platforms):
            ax = plt.subplot(gs[0, idx])
            _plot_sv_size_distribution(
                analysis_data.filter(pl.col(platform)),
                sv_types,
                label,
                ax,
                gs,
            )

            from matplotlib.lines import Line2D

            scatter_handles, scatter_labels = ax.get_legend_handles_labels()

            line_of_best_fit = Line2D(
                [],
                [],
                color="grey",
                linestyle="-",
                alpha=0.8,
                label="Line of best fit",
            )

            ci_patch = Patch(color="grey", alpha=0.2, label="95% CI")

            handles = scatter_handles + [line_of_best_fit, ci_patch]
            labels = scatter_labels + ["Line of best fit", "95% CI"]

            legend = ax.legend(
                handles,
                labels,
                title="SV Type",
                loc="upper left",
                bbox_to_anchor=(1.05, 1),
            )
            legend.get_title().set_weight("bold")

        if gs is None:
            plt.tight_layout()
            return fig
        return None

    except Exception as e:
        logger.error(f"Error plotting SV size vs depth: {str(e)}")
        raise


sv_depth_data, sv_types = prepare_sv_depth_data(total_depth_df, sv_data_df)

size_depth_plot = plot_sv_size_vs_depth(
    sv_depth_data,
    sv_types,
)
No description has been provided for this image

Combined Plots¶

In [55]:
def create_combined_sv_depth_plot(
    wg_depth_df: pl.DataFrame,
    total_depth_df: pl.DataFrame,
    sv_data_df: pl.DataFrame,
    figsize: Tuple[int, int] = (12, 8),
    dpi: int = 300,
) -> plt.Figure:
    """
    Create a combined figure showing SV calls vs depth and SV size vs depth analyses.

    Args:
        wg_depth_df: DataFrame containing whole genome depth statistics
        total_depth_df: DataFrame containing total depth statistics
        sv_data_df: DataFrame containing structural variant data
        figsize: Figure size as (width, height)
        dpi: Figure resolution

    Returns:
        Combined figure object

    Raises:
        ValueError: If required columns are missing from input DataFrames
    """
    try:
        fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = fig.add_gridspec(3, 1, height_ratios=[1, 0.1, 1])

        gs_top = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[0])
        gs_bottom = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[2])

        # Plot SV calls vs depth in top row
        plot_depth_vs_sv_performance(wg_depth_df, sv_data_df, gs=gs_top)

        # Plot SV size vs depth in bottom row
        plot_sv_size_vs_depth(sv_depth_data, sv_types, gs=gs_bottom)

        axes = plt.gcf().get_axes()

        fig.text(
            0.5,
            1.01,
            "SV Calls vs Whole Genome Mean Depth",
            ha="center",
            va="center",
        )

        for idx, ax in enumerate(axes[:2]):
            ax.text(
                -0.1,
                1.05,
                chr(65 + idx),  # A, B
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
                va="top",
            )

        for idx, ax in enumerate(axes[2:]):
            ax.text(
                -0.1,
                1.05,
                chr(67 + idx),  # C
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold",
                va="top",
            )

        fig.set_constrained_layout(True)
        return fig
    except Exception as e:
        logger.error(f"Error creating combined SV analysis plot: {str(e)}")
        raise


combined_sv_analysis_plot = create_combined_sv_depth_plot(
    wg_depth_df, total_depth_df, sv_data_df
)
No description has been provided for this image